Co-authored-by: gongweibao <gognweibao@baidu.com>
gongweibao
2026-03-04 21:55:31 +08:00
committed by GitHub
parent 5c8f5184d9
commit ddb06ff83f
306 changed files with 40627 additions and 34418 deletions
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,18 +19,15 @@
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/interleaved_numeric_conversion.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
// We need to distinguish here, since we want volta support. It is too much effort
// to write shared memory iterators that are probably needed for volta to function
// properly. As a result, we allow converters both after the LDG (for volta) and after
// the LDS for Turing+.
// We need to distinguish here, since we want volta support. It is too much
// effort to write shared memory iterators that are probably needed for volta to
// function properly. As a result, we allow converters both after the LDG (for
// volta) and after the LDS for Turing+.
template <
/// Iterator for B matrix in global memory
typename IteratorB,
@@ -38,9 +35,7 @@ template <
typename MmaOperator,
/// Math operation performed by the warp-level operator
typename MathOperator>
struct SetConverters
{
};
struct SetConverters {};
// Dequantize after LDG, so set transforms accordingly
template <
@@ -48,14 +43,16 @@ template <
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd>
{
using TransformAfterLDG
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename IteratorB::Element, IteratorB::Fragment::kElements>;
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAdd> {
using TransformAfterLDG = FastInterleavedAndBiasedNumericArrayConverter<
typename MmaOperator::ArchMmaOperator::ElementB,
typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS = NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename MmaOperator::ArchMmaOperator::ElementB, MmaOperator::FragmentB::kElements>;
using TransformAfterLDS =
NumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename MmaOperator::ArchMmaOperator::ElementB,
MmaOperator::FragmentB::kElements>;
};
// Dequantize after LDS, so set transforms accordingly
@@ -65,14 +62,18 @@ template <
typename IteratorB,
/// Mma Policy
typename MmaOperator>
struct SetConverters<IteratorB, MmaOperator, arch::OpMultiplyAddDequantizeInterleavedBToA>
{
using TransformAfterLDG = NumericArrayConverter<typename IteratorB::Element, typename IteratorB::Element,
IteratorB::Fragment::kElements>;
struct SetConverters<IteratorB,
MmaOperator,
arch::OpMultiplyAddDequantizeInterleavedBToA> {
using TransformAfterLDG =
NumericArrayConverter<typename IteratorB::Element,
typename IteratorB::Element,
IteratorB::Fragment::kElements>;
using TransformAfterLDS
= FastInterleavedAndBiasedNumericArrayConverter<typename MmaOperator::ArchMmaOperator::ElementB,
typename TransformAfterLDG::result_type::Element, MmaOperator::FragmentB::kElements>;
using TransformAfterLDS = FastInterleavedAndBiasedNumericArrayConverter<
typename MmaOperator::ArchMmaOperator::ElementB,
typename TransformAfterLDG::result_type::Element,
MmaOperator::FragmentB::kElements>;
};
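The primary template plus the two specializations above implement the LDG/LDS choice described in the comment at the top of this hunk: with arch::OpMultiplyAdd (the Volta path) the dequantizing converter runs after the global load, while arch::OpMultiplyAddDequantizeInterleavedBToA defers dequantization until after the shared-memory load for Turing+. The following is a minimal standalone sketch of that tag-dispatch pattern; the tag and converter names are simplified stand-ins, not the real CUTLASS types.

```cpp
#include <iostream>

struct OpMultiplyAdd {};                           // Volta-style tag: dequantize after LDG
struct OpMultiplyAddDequantizeInterleavedBToA {};  // Turing+ tag: dequantize after LDS

struct DequantizingConverter { static constexpr const char* name = "dequantize"; };
struct PassThroughConverter  { static constexpr const char* name = "pass-through"; };

// Primary template: operators without a specialization get no converters.
template <typename MathOperator>
struct SetConvertersSketch {};

template <>
struct SetConvertersSketch<OpMultiplyAdd> {
  using TransformAfterLDG = DequantizingConverter;  // convert right after the global load
  using TransformAfterLDS = PassThroughConverter;   // shared-memory load needs no conversion
};

template <>
struct SetConvertersSketch<OpMultiplyAddDequantizeInterleavedBToA> {
  using TransformAfterLDG = PassThroughConverter;   // keep quantized data through smem
  using TransformAfterLDS = DequantizingConverter;  // convert right before the tensor-core MMA
};

int main() {
  using Volta  = SetConvertersSketch<OpMultiplyAdd>;
  using Turing = SetConvertersSketch<OpMultiplyAddDequantizeInterleavedBToA>;
  std::cout << "Volta path:  LDG=" << Volta::TransformAfterLDG::name
            << ", LDS=" << Volta::TransformAfterLDS::name << "\n";
  std::cout << "Turing path: LDG=" << Turing::TransformAfterLDG::name
            << ", LDS=" << Turing::TransformAfterLDS::name << "\n";
  return 0;
}
```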
////////////////////////////////////////////////////////////////////////////////
@@ -120,6 +121,6 @@ template <
typename Enable = void>
struct DqMma;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,49 +27,77 @@
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
template <typename MmaShape,
typename Element,
typename Layout,
WeightOnlyQuantOp QuantOp,
int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsMultistage;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
template <typename MmaShape,
typename Element,
typename Layout,
WeightOnlyQuantOp QuantOp,
int Alignment>
struct DefaultScaleIteratorsMultistage<
MmaShape,
Element,
Layout,
QuantOp,
Alignment,
std::enable_if_t<isFinegrained(QuantOp)>> {
using IteratorScale =
cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
cutlass::MatrixShape<1, MmaShape::kN>,
Element,
Layout,
0,
Alignment>;
using SmemIteratorScale = IteratorScale;
using SmemIteratorScale = IteratorScale;
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsMultistage<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
// ThreadMap for scale iterator
static_assert((MmaShape::kN % Alignment) == 0, "");
template <typename MmaShape,
typename Element,
typename Layout,
WeightOnlyQuantOp QuantOp,
int Alignment>
struct DefaultScaleIteratorsMultistage<
MmaShape,
Element,
Layout,
QuantOp,
Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>> {
// ThreadMap for scale iterator
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
private:
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment,
Alignment>;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaShape::kN>,
Element,
Layout,
0,
IteratorScaleThreadMap,
Alignment>;
using SmemIteratorScale = IteratorScale;
using SmemIteratorScale = IteratorScale;
};
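The two DefaultScaleIteratorsMultistage specializations above are partitioned with std::enable_if_t on isFinegrained(QuantOp): fine-grained quantization selects the FineGrainedScaleZeroIterator, while per-column scaling selects a strip-mined PredicatedTileIterator. A minimal sketch of that SFINAE split, using a hypothetical enum and predicate rather than the real WeightOnlyQuantOp:

```cpp
#include <type_traits>

enum class QuantOpSketch { PerColumnScaleOnly, FinegrainedScaleOnly, FinegrainedScaleAndZeros };

constexpr bool isFinegrainedSketch(QuantOpSketch op) {
  return op == QuantOpSketch::FinegrainedScaleOnly ||
         op == QuantOpSketch::FinegrainedScaleAndZeros;
}

// Primary template is declared but never instantiated directly.
template <QuantOpSketch Op, typename Enable = void>
struct DefaultScaleIteratorsSketch;

// Selected when the quantization scheme is fine-grained.
template <QuantOpSketch Op>
struct DefaultScaleIteratorsSketch<Op, std::enable_if_t<isFinegrainedSketch(Op)>> {
  static constexpr const char* kind = "fine-grained scale/zero iterator";
};

// Selected for per-column scaling.
template <QuantOpSketch Op>
struct DefaultScaleIteratorsSketch<Op, std::enable_if_t<!isFinegrainedSketch(Op)>> {
  static constexpr const char* kind = "per-column predicated tile iterator";
};

static_assert(
    DefaultScaleIteratorsSketch<QuantOpSketch::FinegrainedScaleOnly>::kind != nullptr, "");
static_assert(
    DefaultScaleIteratorsSketch<QuantOpSketch::PerColumnScaleOnly>::kind != nullptr, "");

int main() { return 0; }
```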
////////////////////////////////////////////////////////////////////////////////
@@ -111,69 +139,133 @@ template <
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator_, SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator_,
SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 &&
!layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value ||
platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
  // Mma core does not depend on stages, so pass in at least 3 here so that the
  // mma multistage pieces are created
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define the MmaCore components
    // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, std::max(kStages, 3),
Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
ElementB,
LayoutB,
0,
ThreadMapB,
AccessTypeB>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>, ElementB, LayoutB, 0, ThreadMapB,
AccessTypeB>;
using ScaleIterators =
DefaultScaleIteratorsMultistage<typename MmaCore::Shape,
ElementScale,
LayoutScale,
OperatorInfo::QuantOp,
kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementScale,
ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
Converter,
OperatorInfo::QuantOp,
SharedMemoryClear>;
};
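The CacheOpA/CacheOpB selection in this specialization checks whether a single access moves a full 128 bits (sizeof_bits of the element times the alignment); only then is the cp.async issued with CacheOperation::Global, otherwise CacheOperation::Always is used. A quick worked check of that arithmetic, with illustrative alignments that are assumptions rather than values taken from this diff:

```cpp
constexpr int kBitsHalf  = 16;  // sizeof_bits<half_t>::value
constexpr int kBitsUint8 = 8;   // sizeof_bits<uint8_t>::value

constexpr int kAlignmentA = 8;   // hypothetical: 8 half_t elements per access
constexpr int kAlignmentB = 16;  // hypothetical: 16 uint8_t elements per access

static_assert(kBitsHalf * kAlignmentA == 128,
              "128-bit A access -> CacheOperation::Global");
static_assert(kBitsUint8 * kAlignmentB == 128,
              "128-bit B access -> CacheOperation::Global");
static_assert(kBitsUint8 * 4 != 128,
              "a narrower B access would fall back to CacheOperation::Always");

int main() { return 0; }
```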
// Specialization to handle column-major interleaved B
@@ -214,89 +306,159 @@ template <
typename Operator_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape,
kStages, Operator_, SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator_,
SharedMemoryClear,
typename platform::enable_if<(
ArchTag::kMinComputeCapability >= 80 &&
layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value ||
platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value
|| platform::is_same<ElementA, float_e4m3_t>::value,
"Element A must be fp16, fp8 or bf16");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(platform::is_same<Operator, arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
  // Mma core does not depend on stages, so pass in at least 3 here so that the
  // mma multistage pieces are created
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define the MmaCore components
    // Mma core does not depend on stages, so pass in at least 3 here so that the mma multistage pieces are created
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
ElementA, LayoutA, ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, OperatorClass,
std::max(kStages, 3), Operator, false, CacheOpA, CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>, ElementA, LayoutA, 1, ThreadMapA,
AccessTypeA>;
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement =
typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape =
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow,
GmemIteratorShape::kColumn>,
OriginalThreadMap::kThreads,
layout::PitchLinearShape<
OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
using GmemIteratorShape
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
GmemIteratorShape,
ElementB,
layout::ColumnMajor,
0,
GmemThreadMapB,
AccessTypeB>;
public:
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, AccessTypeB>;
using ScaleIterators =
DefaultScaleIteratorsMultistage<typename MmaCore::Shape,
ElementScale,
LayoutScale,
OperatorInfo::QuantOp,
kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsMultistage<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<
ElementScale,
ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
using Converter = FastInterleavedAndBiasedNumericArrayConverter<ElementScale, ElementB,
MmaCore::MmaPolicy::Operator::FragmentB::kElements>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB, IteratorScale, SmemIteratorScale, ElementAccumulator, layout::RowMajor,
typename MmaCore::MmaPolicy, kStages, Converter, OperatorInfo::QuantOp, SharedMemoryClear>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
Converter,
OperatorInfo::QuantOp,
SharedMemoryClear>;
};
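In this interleaved-B specialization the B tile is reshaped for the global-memory iterator: GmemIteratorShape becomes MatrixShape of kK * ColumnsInterleaved rows by kN / ColumnsInterleaved columns, with static_asserts guarding divisibility and RowsPerTile == kK. A worked example of that shape math, using hypothetical tile sizes that are not taken from this diff:

```cpp
constexpr int kK = 64;                 // hypothetical MmaCore::Shape::kK
constexpr int kN = 128;                // hypothetical MmaCore::Shape::kN
constexpr int ColumnsInterleaved = 4;  // hypothetical LayoutB::kColumnsInterleaved
constexpr int RowsPerTile = 64;        // hypothetical LayoutB::kRowsPerTile

static_assert(kN % ColumnsInterleaved == 0,
              "the N tile must split evenly across the interleaved columns");
static_assert(RowsPerTile == kK,
              "each interleaved column must supply one full K tile");

// The global-memory iterator then walks a (kK * ColumnsInterleaved) x
// (kN / ColumnsInterleaved) view of B instead of a kK x kN view.
constexpr int GmemRows    = kK * ColumnsInterleaved;  // 256
constexpr int GmemColumns = kN / ColumnsInterleaved;  // 32
static_assert(GmemRows == 256 && GmemColumns == 32, "shape of the interleaved gmem tile");

int main() { return 0; }
```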
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,58 +27,95 @@
#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h"
#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h"
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment,
typename Enable = void>
template <typename MmaShape,
typename Element,
typename Layout,
WeightOnlyQuantOp QuantOp,
int Alignment,
typename Enable = void>
struct DefaultScaleIteratorsPipelined;
// Fine grained iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<isFinegrained(QuantOp)>>
{
private:
using SmemScaleType = half_t;
template <typename MmaShape,
typename Element,
typename Layout,
WeightOnlyQuantOp QuantOp,
int Alignment>
struct DefaultScaleIteratorsPipelined<
MmaShape,
Element,
Layout,
QuantOp,
Alignment,
std::enable_if_t<isFinegrained(QuantOp)>> {
private:
using SmemScaleType = half_t;
public:
using IteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>, Element,
Layout, 0, Alignment>;
public:
using IteratorScale =
cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
cutlass::MatrixShape<1, MmaShape::kN>,
Element,
Layout,
0,
Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::FineGrainedScaleZeroIterator<cutlass::MatrixShape<1, MmaShape::kN>,
SmemScaleType, Layout, 0, Alignment>;
using SmemIteratorScale =
cutlass::transform::threadblock::FineGrainedScaleZeroIterator<
cutlass::MatrixShape<1, MmaShape::kN>,
SmemScaleType,
Layout,
0,
Alignment>;
};
// Per column iterators
template <typename MmaShape, typename Element, typename Layout, WeightOnlyQuantOp QuantOp, int Alignment>
struct DefaultScaleIteratorsPipelined<MmaShape, Element, Layout, QuantOp, Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>>
{
static_assert((MmaShape::kN % Alignment) == 0, "");
template <typename MmaShape,
typename Element,
typename Layout,
WeightOnlyQuantOp QuantOp,
int Alignment>
struct DefaultScaleIteratorsPipelined<
MmaShape,
Element,
Layout,
QuantOp,
Alignment,
std::enable_if_t<!isFinegrained(QuantOp)>> {
static_assert((MmaShape::kN % Alignment) == 0, "");
private:
// ThreadMap for scale iterator
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment, Alignment>;
using SmemScaleType = half_t;
private:
// ThreadMap for scale iterator
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaShape::kN, 1>,
MmaShape::kN / Alignment,
Alignment>;
using SmemScaleType = half_t;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>,
Element, Layout, 0, IteratorScaleThreadMap, Alignment>;
public:
// Define iterators over tiles from the scale operand
using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaShape::kN>,
Element,
Layout,
0,
IteratorScaleThreadMap,
Alignment>;
using SmemIteratorScale
= cutlass::transform::threadblock::PredicatedTileIterator<cutlass::MatrixShape<1, MmaShape::kN>, SmemScaleType,
Layout, 0, IteratorScaleThreadMap, Alignment>;
using SmemIteratorScale =
cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<1, MmaShape::kN>,
SmemScaleType,
Layout,
0,
IteratorScaleThreadMap,
Alignment>;
};
////////////////////////////////////////////////////////////////////////////////
@@ -116,57 +153,110 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
Operator_, SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 && !layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator_,
SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 &&
!layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(OperatorInfo::QuantOp ==
WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY,
"");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
static constexpr bool DqAfterLDG =
platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::
conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaCoreElementA,
LayoutA,
MmaCoreElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
2,
Operator>;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
MmaCoreElementA, LayoutA, MmaCoreElementB, LayoutB, ElementAccumulator, layout::RowMajor, OperatorClass, 2,
Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
ElementA,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
ElementB,
LayoutB,
0,
typename MmaCore::IteratorThreadMapB,
kAlignmentB>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>, ElementB, LayoutB, 0,
typename MmaCore::IteratorThreadMapB, kAlignmentB>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape,
ElementScale,
LayoutScale,
OperatorInfo::QuantOp,
kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters =
SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS,
OperatorInfo::QuantOp>;
};
// Specialization to handle column-major interleaved B
@@ -203,82 +293,140 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator_>
struct DqMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementScale, LayoutScale, kAlignmentScale,
ElementAccumulator, layout::RowMajor, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2,
Operator_, SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 && layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type>
{
struct DqMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementScale,
LayoutScale,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator_,
SharedMemoryClearOption::kNone,
typename platform::enable_if<(
ArchTag::kMinComputeCapability < 80 &&
layout::IsColumnMajorTileInterleave<LayoutB>::value)>::type> {
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementA, half_t>::value || platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint8_t>::value ||
platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
static_assert(platform::is_same<ElementB, uint8_t>::value || platform::is_same<ElementB, uint4b_t>::value,
"Element B must be uint8 or uint4");
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
using OperatorInfo = arch::DetagOperator<Operator_>;
using Operator = typename OperatorInfo::Operator;
static constexpr bool DqAfterLDG =
platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::
conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
static constexpr bool DqAfterLDG = platform::is_same<arch::OpMultiplyAdd, Operator>::value;
using MmaCoreElementA = half_t;
using MmaCoreElementB = typename platform::conditional<DqAfterLDG, MmaCoreElementA, ElementB>::type;
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaCoreElementA,
LayoutA,
MmaCoreElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
2,
Operator>;
// Define the MmaCore components
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
MmaCoreElementA, LayoutA, MmaCoreElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor,
OperatorClass, 2, Operator>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
ElementA,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA>;
// Define iterators over tiles from the A operand
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>, ElementA, LayoutA, 1,
typename MmaCore::IteratorThreadMapA, kAlignmentA>;
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
private:
static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int RowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), "");
static_assert(RowsPerTile == MmaCore::Shape::kK, "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement =
typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using OriginalThreadMap = typename MmaCore::IteratorThreadMapB;
using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement;
static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), "");
using GmemIteratorShape =
MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved,
MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow,
GmemIteratorShape::kColumn>,
OriginalThreadMap::kThreads,
layout::PitchLinearShape<
OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
using GmemIteratorShape
= MatrixShape<MmaCore::Shape::kK * ColumnsInterleaved, MmaCore::Shape::kN / ColumnsInterleaved>;
using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<GmemIteratorShape::kRow, GmemIteratorShape::kColumn>, OriginalThreadMap::kThreads,
layout::PitchLinearShape<OriginalWarpArrangement::kContiguous * ColumnsInterleaved,
OriginalWarpArrangement::kStrided / ColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
GmemIteratorShape,
ElementB,
layout::ColumnMajor,
0,
GmemThreadMapB,
kAlignmentB>;
public:
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<GmemIteratorShape, ElementB,
layout::ColumnMajor, 0, GmemThreadMapB, kAlignmentB>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale,
kAlignmentScale>;
// ThreadMap for scale iterator
static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, "");
using IteratorScaleThreadMap
= transform::PitchLinearStripminedThreadMap<layout::PitchLinearShape<MmaCore::Shape::kN, 1>,
MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape,
ElementScale,
LayoutScale,
OperatorInfo::QuantOp,
kAlignmentScale>;
using ScaleIterators = DefaultScaleIteratorsPipelined<typename MmaCore::Shape, ElementScale, LayoutScale,
OperatorInfo::QuantOp, kAlignmentScale>;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
// Define iterators over tiles from the scale operand
using IteratorScale = typename ScaleIterators::IteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale;
using Converters =
SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
using Converters = SetConverters<IteratorB, typename MmaCore::MmaPolicy::Operator, Operator>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<typename MmaCore::Shape, IteratorA,
typename MmaCore::SmemIteratorA, IteratorB, typename MmaCore::SmemIteratorB, IteratorScale, SmemIteratorScale,
ElementAccumulator, layout::RowMajor, typename MmaCore::MmaPolicy, typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS, OperatorInfo::QuantOp>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
IteratorScale,
SmemIteratorScale,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
typename Converters::TransformAfterLDG,
typename Converters::TransformAfterLDS,
OperatorInfo::QuantOp>;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,7 +27,8 @@ namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2)
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int8 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -49,34 +50,61 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
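Each of these DefaultMma specializations derives the scale alignment from a 128-bit access width: kAlignmentScale = 128 / sizeof_bits of half_t. A tiny worked check of that arithmetic (sizeof_bits is stood in for locally, so this compiles without CUTLASS):

```cpp
constexpr int kSizeofBitsHalf = 16;  // stands in for cutlass::sizeof_bits<half_t>::value
constexpr int kAlignmentScale = 128 / kSizeofBitsHalf;
static_assert(kAlignmentScale == 8, "eight fp16 scale values per 128-bit vector access");

int main() { return 0; }
```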
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2)
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight, mma pipelined (stage=2)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -98,35 +126,61 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, 2, Operator>
{
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, 2, Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage
/// (stage>=3)
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int8 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -152,36 +206,64 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint8_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage
/// (stage>=3)
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int4 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -207,37 +289,65 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<half_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
using Mma = DqMma<half_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t, layout::RowMajor,
kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape,
WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#ifdef ENABLE_FP8
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage
/// (stage>=3)
/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation
/// & int4 weight, mma multistage (stage>=3)
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -263,36 +373,65 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, ElementAccumulator,
layout::RowMajor, arch::OpClassTensorOp, ArchTag, ThreadblockShape, WarpShape, InstructionShape, kStages, Operator,
false, SharedMemoryClear>
{
struct DefaultMma<cutlass::float_e4m3_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<half_t>::value;
using Mma = DqMma<cutlass::float_e4m3_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
half_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
using Mma = DqMma<cutlass::float_e4m3_t, LayoutA, kAlignmentA, uint4b_t, LayoutB, kAlignmentB, half_t,
layout::RowMajor, kAlignmentScale, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag,
ThreadblockShape, WarpShape, InstructionShape, kStages, Operator, SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
#endif
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stages. Helps avoid reg spills on
// large tiles when not enough shared mem is present to do 3+ stages
// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stages.
// Helps avoid reg spills on large tiles when not enough shared mem is present
// to do 3+ stages
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -318,39 +457,86 @@ template <
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<half_t, LayoutA, kAlignmentA, half_t, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor,
arch::OpClassTensorOp, arch::Sm80, ThreadblockShape, WarpShape, InstructionShape, 2, Operator, false,
SharedMemoryClear, GatherA, GatherB>
{
struct DefaultMma<half_t,
LayoutA,
kAlignmentA,
half_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
half_t,
LayoutA,
half_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape, WarpShape, InstructionShape,
half_t, LayoutA, half_t, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, 3, Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<half_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
half_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<half_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
half_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};
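// Illustrative shared-memory estimate (editorial sketch, not part of the
// original header; the tile sizes and the ~100 KB per-block shared-memory
// limit below are assumptions for the example): a multistage pipeline needs
// roughly stages * (kM*kK + kK*kN) * sizeof(half_t) bytes of shared memory,
// so for a 256x128x64 threadblock tile:
//   2 stages: 2 * (256*64 + 64*128) * 2 B = 96 KB   (fits in ~100 KB)
//   3 stages: 3 * (256*64 + 64*128) * 2 B = 144 KB  (does not fit)
// which is why this specialization builds MmaCore with 3 stages (to select
// the multistage components) but instantiates MmaMultistage with only 2.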
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), fp16
/// activation & int2 weight, mma multistage
template <
/// Layout type for A matrix operand
@@ -373,26 +559,50 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
using Mma = DefaultWint2xMma<half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
@@ -420,29 +630,55 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
using Mma = DefaultWint2xMma<half_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
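// Illustrative instantiation (editorial sketch, not part of the original
// header; tile shapes, alignments and stage count are example values):
//
//   using Wint2Mma = DefaultMma<
//       cutlass::half_t, layout::RowMajor, /*kAlignmentA=*/8,
//       cutlass::uint2b_t, LayoutB /* interleaved column-major weight layout */,
//       /*kAlignmentB=*/64, float, layout::RowMajor, arch::OpClassTensorOp,
//       arch::Sm80, GemmShape<128, 128, 64>, GemmShape<64, 64, 64>,
//       GemmShape<16, 8, 16>, /*kStages=*/4,
//       arch::OpMultiplyAddDequantizeInterleavedBToA, false,
//       SharedMemoryClearOption::kNone>;
//
// A configuration like this matches the partial specialization above, which
// simply re-exports the nested types (MmaCore, IteratorA, IteratorB,
// ThreadblockMma) of the corresponding DefaultWint2xMma.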
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -27,7 +27,8 @@ namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & bf16 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -55,40 +56,85 @@ template <
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
bfloat16_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
private:
// Conversions only needed pre-ampere. This will trigger mma pipeline, so we
// convert before STS.
static constexpr bool arch_has_bf16_mma =
ArchTag::kMinComputeCapability >= 80;
using MmaElementA = typename platform::
conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
using MmaElementB = typename platform::
conditional<arch_has_bf16_mma, bfloat16_t, half_t>::type;
public:
// Define the MmaCore components
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
MmaElementA,
LayoutA,
MmaElementB,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
2,
Operator>;
using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kM, MmaCore::Shape::kK>,
bfloat16_t,
LayoutA,
1,
typename MmaCore::IteratorThreadMapA,
kAlignmentA,
GatherA>;
// Define iterators over tiles from the B operand
using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<MmaCore::Shape::kK, MmaCore::Shape::kN>,
bfloat16_t,
LayoutB,
0,
typename MmaCore::IteratorThreadMapB,
kAlignmentB,
GatherB>;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaPipelined<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
IteratorB,
typename MmaCore::SmemIteratorB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy>;
};
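// How the pre-Ampere conversion above resolves (editorial note, not part of
// the original header):
//   ArchTag = arch::Sm75: kMinComputeCapability == 75, arch_has_bf16_mma is
//     false, so MmaElementA/MmaElementB are half_t and the bf16 operands are
//     converted before being stored to shared memory (STS).
//   ArchTag = arch::Sm80: kMinComputeCapability == 80, arch_has_bf16_mma is
//     true, so the tensor-op MMA consumes bfloat16_t directly.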
// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. Helps
// avoid reg spills on large tile when not enough shared mem is present to do 3+
// stage
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -114,40 +160,86 @@ template <
bool GatherA,
/// Gather operand B by using an index array
bool GatherB>
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
bfloat16_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
arch::Sm80,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator,
false,
SharedMemoryClear,
GatherA,
GatherB> {
// Define the MmaCore components
// 3 is used on purpose here to trigger components for mma multistage
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
bfloat16_t,
LayoutA,
bfloat16_t,
LayoutB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
3,
Operator>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<bfloat16_t, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
bfloat16_t,
LayoutA,
1,
ThreadMapA,
AccessTypeA,
GatherA>;
// Define iterators over tiles from the B operand
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using AccessTypeB = cutlass::Array<bfloat16_t, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
bfloat16_t,
LayoutB,
0,
ThreadMapB,
AccessTypeB,
GatherB>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma =
cutlass::gemm::threadblock::MmaMultistage<typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
2>;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -169,34 +261,61 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
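// kAlignmentScale worked out (editorial note, not part of the original
// header): 128 / sizeof_bits<bfloat16_t>::value = 128 / 16 = 8, i.e. the
// per-column dequantization scales are loaded as 8-element vectors so that
// scale loads can use full 128-bit accesses.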
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -218,34 +337,61 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int8 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -271,35 +417,64 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint8_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int4 weight
template <
/// Layout type for A matrix operand
typename LayoutA,
@@ -325,35 +500,64 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
static constexpr int kAlignmentScale = 128 / sizeof_bits<bfloat16_t>::value;
using Mma = DqMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint4b_t,
LayoutB,
kAlignmentB,
bfloat16_t,
layout::RowMajor,
kAlignmentScale,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped pipelined matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
////////////////////////////////////////////////////////////////////////////////
/// Specialization for row-major output (OperatorClass TensorOp), bf16
/// activation & int2 weight, mma multistage
template <
/// Layout type for A matrix operand
@@ -376,26 +580,50 @@ template <
typename InstructionShape,
/// Operation performed by GEMM
typename Operator>
struct DefaultMma<cutlass::bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator> {
private:
using Mma = DefaultWint2xMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
2,
Operator>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
template <
@@ -423,29 +651,55 @@ template <
int kStages,
/// Shared memory clear option
SharedMemoryClearOption SharedMemoryClear>
struct DefaultMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
false,
SharedMemoryClear> {
private:
using Mma = DefaultWint2xMma<bfloat16_t,
LayoutA,
kAlignmentA,
uint2b_t,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
arch::OpClassTensorOp,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear>;
public:
// Define the MmaCore components
using MmaCore = typename Mma::MmaCore;
// Define iterators over tiles from the A operand
using IteratorA = typename Mma::IteratorA;
// Define iterators over tiles from the B operand
using IteratorB = typename Mma::IteratorB;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = typename Mma::ThreadblockMma;
};
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
@@ -66,10 +67,21 @@ template <
cutlass::arch::CacheOperation::Kind CacheOpA,
/// Cache operation of operand B
cutlass::arch::CacheOperation::Kind CacheOpB>
struct DefaultMmaCore<Shape_,
WarpShape_,
InstructionShape_,
ElementA_,
layout::RowMajor,
uint2b_t,
layout::ColumnMajor,
ElementC_,
LayoutC_,
arch::OpClassTensorOp,
Stages,
Operator_,
false,
CacheOpA,
CacheOpB> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
@@ -104,7 +116,8 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
/// Size of a threadblock-scoped access of B
static constexpr int kMaxThreadsForB =
(Shape::kK * Shape::kN * sizeof_bits<ElementB>::value) /
kAccessSizeInBits;
static constexpr int kThreadsForB =
kMaxThreadsForB > kThreads ? kThreads : kMaxThreadsForB;
@@ -129,11 +142,13 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
//
using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementA>::value,
Shape::kK>;
// Shared memory layout
using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise<
sizeof_bits<ElementB>::value,
Shape::kK>;
//
// Iterators to write to shared memory
@@ -141,26 +156,34 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
/// ThreadMap of iterator A
using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kM>,
kThreads,
layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kM, Shape::kK>,
ElementA,
SmemLayoutA,
0,
IteratorThreadMapA>;
/// ThreadMap of iterator B
using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<Shape::kK, Shape::kN>,
kThreadsForB,
layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = transform::threadblock::RegularTileAccessIterator<
MatrixShape<Shape::kK, Shape::kN>,
ElementB,
SmemLayoutB,
1,
IteratorThreadMapB>;
//
@@ -168,13 +191,23 @@ struct DefaultMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
//
// Define the warp-level tensor op
using MmaTensorOp =
typename cutlass::gemm::warp::DefaultMmaTensorOp<WarpShape,
InstructionShape,
ElementA,
SmemLayoutA,
ElementB,
SmemLayoutB,
ElementC,
LayoutC,
Operator,
WarpCount::kK>::Type;
/// Policy used to define MmaPipelined
using MmaPolicy = MmaPolicy<MmaTensorOp,
MatrixShape<0, 0>,
MatrixShape<0, 0>,
WarpCount::kK>;
};
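// Worked example for the B-operand thread split above (editorial sketch, not
// part of the original header; the 128-bit access size and a 128-thread
// threadblock are assumptions): with Shape = GemmShape<128, 128, 64> and
// 2-bit ElementB,
//   kMaxThreadsForB = (64 * 128 * 2) / 128 = 128,
// so every thread can issue one 128-bit access per B tile; with a smaller
// Shape::kN (say 64), only (64 * 64 * 2) / 128 = 64 threads are needed, and
// kThreadsForB clamps the B iterator to that subset.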
} // namespace threadblock
@@ -1,6 +1,6 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -31,47 +31,61 @@ namespace threadblock {
template <typename ThreadblockShape, typename ElementT, int GroupSize>
struct DefaultQuantParamsIterators {
private:
static constexpr int kAlignment = 128 / sizeof_bits<ElementT>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1) ? 1
: (ThreadblockShape::kK + GroupSize - 1) / GroupSize;
static constexpr int kColumns = ThreadblockShape::kN;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment,
kAlignment>;
public:
using Iterator = cutlass::transform::threadblock::PredicatedTileIterator<
MatrixShape<kRows, kColumns>,
ElementT,
layout::RowMajor,
0,
IteratorThreadMap,
kAlignment>;
using SmemIterator = Iterator;
};
template <typename ThreadblockShape, int GroupSize>
struct DefaultQuantParamsIterators<ThreadblockShape, uint4b_t, GroupSize> {
private:
static constexpr int kAlignment = 32 / sizeof_bits<uint4b_t>::value;
static_assert((ThreadblockShape::kN % kAlignment) == 0, "");
static constexpr int kRows =
(GroupSize == -1)
? 1
: (ThreadblockShape::kK + 2 * GroupSize - 1) / (2 * GroupSize);
static constexpr int kColumns =
(GroupSize == -1) ? ThreadblockShape::kN : ThreadblockShape::kN * 2;
using IteratorThreadMap = transform::PitchLinearStripminedThreadMap<
layout::PitchLinearShape<kColumns, kRows>,
kColumns / kAlignment,
kAlignment>;
public:
using AccessType = cutlass::Array<uint4b_t, kAlignment>;
using Iterator =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
MatrixShape<kRows, kColumns>,
uint4b_t,
layout::RowMajor,
0,
IteratorThreadMap,
AccessType>;
using SmemIterator = Iterator;
};
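// Worked example for the uint4b_t local-scale iterator above (editorial note,
// not part of the original header): the arithmetic reads as if two k-groups
// of 4-bit scales are packed along the column dimension, so with
// ThreadblockShape::kN = 128, ThreadblockShape::kK = 64 and GroupSize = 64,
//   kRows    = (64 + 2*64 - 1) / (2*64) = 1
//   kColumns = 128 * 2 = 256
// and each access covers kAlignment = 32 / 4 = 8 elements.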
template <
@@ -142,105 +156,174 @@ template <
typename Operator,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear>
struct DefaultWint2xMma<ElementA,
LayoutA,
kAlignmentA,
ElementB,
LayoutB,
kAlignmentB,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
ArchTag,
ThreadblockShape,
WarpShape,
InstructionShape,
kStages,
Operator,
SharedMemoryClear> {
public:
static_assert(platform::is_same<ElementA, half_t>::value ||
platform::is_same<ElementA, bfloat16_t>::value,
"Element A must be fp16 or bf16");
static_assert(platform::is_same<ElementB, uint2b_t>::value,
"Element B must be uint2b_t");
static_assert(
platform::is_same<Operator,
arch::OpMultiplyAddDequantizeInterleavedBToA>::value,
"Mma multistage must dequantize after ldsm");
using ElementSuperScale = ElementA;
using ElementLocalScale = uint4b_t;
using ElementCodeScaleZp = float;
static constexpr int kGroupSize = 64;
static cutlass::arch::CacheOperation::Kind const CacheOpA =
((sizeof_bits<ElementA>::value * kAlignmentA) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
static cutlass::arch::CacheOperation::Kind const CacheOpB =
((sizeof_bits<ElementB>::value * kAlignmentB) == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
// Define the MmaCore components
// Mma core does not depend on stages, so pass in at least 3 here to ensure
// the mma multistage pieces are created
using MmaCore =
typename cutlass::gemm::threadblock::DefaultMmaCore<ThreadblockShape,
WarpShape,
InstructionShape,
ElementA,
LayoutA,
ElementB,
layout::ColumnMajor,
ElementAccumulator,
layout::RowMajor,
OperatorClass,
std::max(kStages, 3),
Operator,
false,
CacheOpA,
CacheOpB>;
// Define iterators over tiles from the A operand
using ThreadMapA = typename MmaCore::IteratorThreadMapA;
using AccessTypeA = cutlass::Array<ElementA, kAlignmentA>;
using IteratorA =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
ElementA,
LayoutA,
1,
ThreadMapA,
AccessTypeA>;
private:
static constexpr int kColumnsInterleaved = LayoutB::kColumnsInterleaved;
static constexpr int kRowsPerTile = LayoutB::kRowsPerTile;
static_assert(!(MmaCore::Shape::kN % kColumnsInterleaved),
"ThreadblockShape must be divisible by kColumnsInterleaved");
static_assert(kRowsPerTile == MmaCore::Shape::kK, "");
using ThreadMapB = typename MmaCore::IteratorThreadMapB;
using WarpArrangement = typename ThreadMapB::Detail::WarpThreadArrangement;
static_assert(!(WarpArrangement::kStrided % kColumnsInterleaved), "");
using IteratorShapeB = MatrixShape<MmaCore::Shape::kK * kColumnsInterleaved,
MmaCore::Shape::kN / kColumnsInterleaved>;
using InterleavedThreadMapB = transform::PitchLinearWarpRakedThreadMap<
layout::PitchLinearShape<IteratorShapeB::kRow, IteratorShapeB::kColumn>,
ThreadMapB::kThreads,
layout::PitchLinearShape<WarpArrangement::kContiguous *
kColumnsInterleaved,
WarpArrangement::kStrided / kColumnsInterleaved>,
MmaCore::kAccessSizeInBits / sizeof_bits<ElementB>::value>;
public:
// Define iterators over tiles from the B operand
using AccessTypeB = cutlass::Array<ElementB, kAlignmentB>;
using IteratorB =
cutlass::transform::threadblock::PredicatedTileAccessIterator<
IteratorShapeB,
ElementB,
layout::ColumnMajor,
0,
InterleavedThreadMapB,
AccessTypeB>;
private:
// Define iterators over tiles from extra quant params for B operand
using IteratorSuperScale =
typename DefaultQuantParamsIterators<ThreadblockShape,
ElementSuperScale,
-1>::Iterator;
using SmemIteratorSuperScale =
typename DefaultQuantParamsIterators<ThreadblockShape,
ElementSuperScale,
-1>::SmemIterator;
using IteratorLocalScale =
typename DefaultQuantParamsIterators<ThreadblockShape,
ElementLocalScale,
kGroupSize>::Iterator;
using SmemIteratorLocalScale =
typename DefaultQuantParamsIterators<ThreadblockShape,
ElementLocalScale,
kGroupSize>::SmemIterator;
using IteratorCodeScaleZp =
typename DefaultQuantParamsIterators<ThreadblockShape,
ElementCodeScaleZp,
-1>::Iterator;
using SmemIteratorCodeScaleZp =
typename DefaultQuantParamsIterators<ThreadblockShape,
ElementCodeScaleZp,
-1>::Iterator;
public:
using QuantParamsAccessor = Wint2ParamsAccessor<ElementA,
ThreadblockShape,
IteratorSuperScale,
SmemIteratorSuperScale,
IteratorLocalScale,
SmemIteratorLocalScale,
IteratorCodeScaleZp,
SmemIteratorCodeScaleZp,
kStages,
kGroupSize>;
// Define the threadblock-scoped multistage matrix multiply
using ThreadblockMma = cutlass::gemm::threadblock::Wint2xMmaMultistage<
typename MmaCore::Shape,
IteratorA,
typename MmaCore::SmemIteratorA,
MmaCore::kCacheOpA,
IteratorB,
typename MmaCore::SmemIteratorB,
MmaCore::kCacheOpB,
ElementAccumulator,
layout::RowMajor,
typename MmaCore::MmaPolicy,
kStages,
QuantParamsAccessor,
SharedMemoryClear>;
};
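// Cache-operation selection above, worked out (editorial note, not part of
// the original header; the alignments are example values): CacheOp falls back
// to Always unless a thread's access is a full 128-bit vector, e.g.
//   ElementA = half_t,   kAlignmentA = 8  -> 16 * 8  = 128 bits -> Global
//   ElementB = uint2b_t, kAlignmentB = 32 -> 2 * 32  = 64 bits  -> Always
//   ElementB = uint2b_t, kAlignmentB = 64 -> 2 * 64  = 128 bits -> Global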
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -47,30 +48,33 @@
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace threadblock {
////////////////////////////////////////////////////////////////////////////////
// SFINAE trick so I can keep the same loop code for Volta and dispatch to the
// correct warp level mma. On volta, all data is stored to shared memory as
// FP16.
template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma,
typename WarpMma::FragmentC& D,
typename WarpMma::FragmentA const& A,
typename WarpMma::FragmentB const& B,
typename WarpMma::FragmentC const& C,
int const warp_tileB_k_offset) {
warp_mma(D, A, B, C);
}
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(
WarpMma& warp_mma,
typename WarpMma::FragmentC& D,
typename WarpMma::TransformedFragmentA const& A,
typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C,
int const warp_tileB_k_offset) {
warp_mma(D, A, B, C, warp_tileB_k_offset);
}
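// Which overload runs (editorial note, not part of the original header): the
// Volta path passes raw FragmentA/FragmentB, so overload resolution picks the
// first run_warp_mma and the k-offset argument is ignored; Turing and newer
// pipelines pass TransformedFragmentA/TransformedFragmentB (dequantized after
// LDSM), so the second overload is chosen and warp_tileB_k_offset is
// forwarded to the warp-level mma.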
////////////////////////////////////////////////////////////////////////////////
@@ -90,168 +94,169 @@ template <
WeightOnlyQuantOp DequantOp,
/// Used for partial specialization,
typename Enable = bool>
class DqMmaBase {
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Policy describing tuning details
using Policy = Policy_;
///< Type of the scale to be loaded
using ElementScale = ElementScale_;
static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, "");
// Finegrained scales get streamed in via cp.async
static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1;
// We always have scales.
static constexpr int ScaleElementsPerStage = Shape::kN;
// We sometimes have a bias
static constexpr int BiasElementsPerStage =
hasZero(DequantOp) ? Shape::kN : 0;
//
// Dependent types
//
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Shape describing the overall GEMM computed from shared memory
/// by each warp.
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
(WarpGemm::kK / Operator::Policy::MmaShape::kK);
static constexpr int kNumKIterationsPerWarpBLoad =
Operator::IteratorB::InstructionShape::kRow /
Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), "");
static constexpr int kWarpGemmIterationsForB =
kWarpGemmIterations / kNumKIterationsPerWarpBLoad;
/// Number of stages
static int const kStages = Stages;
/// Tensor reference to the A operand
using TensorRefA =
TensorRef<typename Operator::ElementA, typename Operator::LayoutA>;
/// Tensor reference to the B operand
using TensorRefB =
TensorRef<typename Operator::ElementB, typename Operator::LayoutB>;
//
// Nested structs
//
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
//
// Type definitions
//
/// Shape of the A matrix operand in shared memory
using ShapeA =
MatrixShape<Shape::kM + Policy::SmemPaddingA::kRow,
Shape::kK * kStages + Policy::SmemPaddingA::kColumn>;
/// Shape of the B matrix operand in shared memory
using ShapeB = MatrixShape<Shape::kK * kStages + Policy::SmemPaddingB::kRow,
Shape::kN + Policy::SmemPaddingB::kColumn>;
/// Shape of the shared memory buffer for the scales for the B matrix.
using ShapeScale = MatrixShape<ScalebiasStages, ScaleElementsPerStage>;
/// Shape of the shared memory buffer for the biases of the B matrix.
using ShapeZero = MatrixShape<ScalebiasStages, BiasElementsPerStage>;
public:
//
// Data members
//
/// Buffer for A operand
AlignedBuffer<typename Operator::ElementA, ShapeA::kCount> operand_A;
/// Buffer for B operand
AlignedBuffer<typename Operator::ElementB, ShapeB::kCount> operand_B;
/// Buffer to hold scales for threadblock
AlignedBuffer<ElementScale, ShapeScale::kCount> operand_scale;
/// Buffer to hold zero-points/biases for threadblock
AlignedBuffer<ElementScale, ShapeZero::kCount> operand_zero;
public:
//
// Methods
//
/// Returns a layout object for the A matrix
CUTLASS_DEVICE
static typename Operator::LayoutA LayoutA() {
return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});
}
/// Returns a layout object for the B matrix
CUTLASS_HOST_DEVICE
static typename Operator::LayoutB LayoutB() {
return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});
}
/// Returns a TensorRef to the A operand
CUTLASS_HOST_DEVICE
TensorRefA operand_A_ref() {
return TensorRefA{operand_A.data(), LayoutA()};
}
/// Returns a TensorRef to the B operand
CUTLASS_HOST_DEVICE
TensorRefB operand_B_ref() {
return TensorRefB{operand_B.data(), LayoutB()};
}
};
protected:
//
// Data members
//
/// Iterator to load a warp-scoped tile of A operand from shared memory
typename Operator::IteratorA warp_tile_iterator_A_;
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaBase(
///< Shared storage needed for internal use by threadblock-scoped GEMM
SharedStorage& shared_storage,
///< ID within the threadblock
int thread_idx,
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx),
warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
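// A minimal standalone sketch of the WarpCount arithmetic that DqMmaBase
// derives from Shape and WarpGemm above; the 128x128x64 CTA tile and 64x64x64
// warp tile are hypothetical values chosen only for illustration.
#include <cassert>
int main() {
  int const cta_m = 128, cta_n = 128, cta_k = 64;   // Shape (hypothetical)
  int const warp_m = 64, warp_n = 64, warp_k = 64;  // WarpGemm (hypothetical)
  // WarpCount = GemmShape<Shape::kM / WarpGemm::kM, ...>: a 2x2x1 warp grid.
  int const warps = (cta_m / warp_m) * (cta_n / warp_n) * (cta_k / warp_k);
  assert(warps == 4);  // 4 warps cooperate on one threadblock tile
  return 0;
}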
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -48,12 +49,9 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -102,9 +100,9 @@ template <
typename Enable = void>
class DqMmaMultistage;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h"
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -53,27 +54,27 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
@@ -98,9 +99,9 @@ template <
typename Enable = void>
class DqMmaPipelined;
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h"
#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h"
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -53,27 +54,27 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
@@ -94,393 +95,442 @@ template <
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
class DqMmaPipelined<Shape_,
IteratorA_,
SmemIteratorA_,
IteratorB_,
SmemIteratorB_,
IteratorScale_,
SmemIteratorScale_,
ElementC_,
LayoutC_,
Policy_,
TransformBAfterLDG_,
TransformBAfterLDS_,
QuantOp_,
std::enable_if_t<isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_,
Policy_,
typename SmemIteratorScale_::Element,
2,
QuantOp_> {
public:
///< Base class
using Base = DqMmaBase<Shape_,
Policy_,
typename SmemIteratorScale_::Element,
2,
QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Element, LayoutScale, 32, QuantOp>;
using Dequantizer =
warp::MmaTensorOpDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
typename SmemIteratorScale::Element,
LayoutScale,
32,
QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
// statically assert kStages for DqMmaPipelined is two (Double-buffered
// pipeline)
static_assert((Base::kStages == 2),
              "DqMmaPipelined requires kStages set to value 2");
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, "");
static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, "");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
using WarpFragmentZero = typename Dequantizer::FragmentZero;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
using WarpFragmentZero = typename Dequantizer::FragmentZero;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
static constexpr bool RequiresTileInterleave =
layout::IsColumnMajorTileInterleave<
typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave ||
(RequiresTileInterleave &&
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory
SmemIteratorScale smem_iterator_scale_;
/// Iterator to write threadblock-scoped tile of scale and zero operand to
/// shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< The group size for quantization
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size)
{
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal
///< use by threadblock-scoped GEMM
int const group_size, ///< The group size for quantization
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_dequantizer_(
{shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
{shared_storage.operand_zero.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
Base::WarpCount::kM,
lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_iterator_scale_(LayoutScale(Shape::kN),
shared_storage.operand_scale.data(),
shared_storage.operand_zero.data(),
{Base::kStages, Shape::kN},
thread_idx,
group_size) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale) {
using TransformScale =
NumericArrayConverter<typename SmemIteratorScale::Element,
typename FragmentScale::Element,
FragmentScale::kElements>;
FragmentScale tb_frag_scales;
FragmentScale tb_frag_zeros;
tb_frag_scales.clear();
tb_frag_zeros.clear();
TransformScale transformScale;
using FragmentElement = typename FragmentScale::Element;
auto gmem_scale_ptr = iterator_scale.get_scale();
auto gmem_zero_ptr = iterator_scale.get_zero();
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr) {
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
}
CUTLASS_DEVICE
void copy_scales_and_advance(IteratorScale& iterator_scale)
{
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
typename TransformScale::result_type tb_frag_scales_fp16 =
transformScale(tb_frag_scales);
typename TransformScale::result_type tb_frag_zeros_fp16;
if (gmem_zero_ptr != nullptr)
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
FragmentScale tb_frag_scales;
FragmentScale tb_frag_zeros;
tb_frag_scales.clear();
tb_frag_zeros.clear();
auto frag_scale_ptr_fp16 =
reinterpret_cast<typename SmemIteratorScale::Element*>(
&tb_frag_scales_fp16);
auto frag_zero_ptr_fp16 =
reinterpret_cast<typename SmemIteratorScale::Element*>(
&tb_frag_zeros_fp16);
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
TransformScale transformScale;
if (iterator_scale.valid()) {
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset,
frag_scale_ptr_fp16);
using FragmentElement = typename FragmentScale::Element;
auto gmem_scale_ptr = iterator_scale.get_scale();
auto gmem_zero_ptr = iterator_scale.get_zero();
arch::global_load<FragmentScale, sizeof(FragmentScale)>(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid());
if (gmem_zero_ptr != nullptr)
{
arch::global_load<FragmentScale, sizeof(FragmentScale)>(
tb_frag_zeros, gmem_zero_ptr, iterator_scale.valid());
}
typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales);
typename TransformScale::result_type tb_frag_zeros_fp16;
if (gmem_zero_ptr != nullptr)
tb_frag_zeros_fp16 = transformScale(tb_frag_zeros);
auto frag_scale_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_scales_fp16);
auto frag_zero_ptr_fp16 = reinterpret_cast<typename SmemIteratorScale::Element*>(&tb_frag_zeros_fp16);
auto smem_scale_ptr = this->smem_iterator_scale_.get_scale();
auto smem_zero_ptr = this->smem_iterator_scale_.get_zero();
if (iterator_scale.valid())
{
auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_scale_ptr_fp16);
if (gmem_zero_ptr != nullptr)
{
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset, frag_zero_ptr_fp16);
}
}
if (iterator_scale.group_size_ == 64)
{
iterator_scale.add_tile_offset({1, 0});
}
else if (iterator_scale.group_size_ == 128)
{
if constexpr (Shape::kK == 128)
{
iterator_scale.add_tile_offset({1, 0});
}
else if constexpr (Shape::kK == 64)
{
if (iterator_scale.row_groupsize64_ & 0x1)
{
iterator_scale.add_tile_offset({1, 0});
}
}
else
{
static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128");
}
}
iterator_scale.row_groupsize64_++;
this->smem_iterator_scale_.add_tile_offset({1, 0});
if (gmem_zero_ptr != nullptr) {
smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr);
arch::shared_store<sizeof(FragmentScale)>(smem_offset,
frag_zero_ptr_fp16);
}
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum)
{ ///< source accumulator tile
if (iterator_scale.group_size_ == 64) {
iterator_scale.add_tile_offset({1, 0});
} else if (iterator_scale.group_size_ == 128) {
if constexpr (Shape::kK == 128) {
iterator_scale.add_tile_offset({1, 0});
} else if constexpr (Shape::kK == 64) {
if (iterator_scale.row_groupsize64_ & 0x1) {
iterator_scale.add_tile_offset({1, 0});
}
} else {
static_assert(Shape::kK == 0,
"Unsupported k tile shape, can only be 64 or 128");
}
}
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
iterator_scale.row_groupsize64_++;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
this->smem_iterator_scale_.add_tile_offset({1, 0});
}
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
TransformA transformA;
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale
iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum) { ///< source accumulator tile
// Perform accumulation in the 'd' output operand
accum = src_accum;
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
using TransformA = NumericArrayConverter<typename WarpFragmentA::Element,
typename FragmentA::Element,
FragmentA::kElements>;
tb_frag_A.clear();
tb_frag_B.clear();
// These transforms are mainly to handle when we have bfloat activations and
// weights in GMEM and want to issue HMMA on architectures older than
// Ampere. We will convert to FP16 before STS.
TransformA transformA;
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
// Perform accumulation in the 'd' output operand
accum = src_accum;
++iterator_A;
++iterator_B;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
tb_frag_A.clear();
tb_frag_B.clear();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
copy_scales_and_advance(iterator_scale);
++iterator_A;
++iterator_B;
__syncthreads();
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
WarpFragmentScale warp_frag_scales;
WarpFragmentZero warp_frag_zero;
++this->smem_iterator_A_;
++this->smem_iterator_B_;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
copy_scales_and_advance(iterator_scale);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
__syncthreads();
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
WarpFragmentScale warp_frag_scales;
WarpFragmentZero warp_frag_zero;
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
warp_dequantizer_.add_pointer_offset(Shape::kN);
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
iterator_scale.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterationsForB,
0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
warp_dequantizer_.add_pointer_offset(Shape::kN);
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
iterator_scale.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN);
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
copy_scales_and_advance(iterator_scale);
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
iterator_scale.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
// Load the scales needed for the next tile iteration
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
// Update internal pointer to the set of scales in shared memory
warp_dequantizer_.add_pointer_offset(Shape::kN);
int const warp_tileB_k_compute_offset =
warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset =
warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate
// the load for the next fragment.
if (warp_tileB_k_compute_offset ==
Base::kNumKIterationsPerWarpBLoad - 1) {
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(
warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
copy_scales_and_advance(iterator_scale);
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
iterator_scale.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(
converted_frag_B, warp_frag_scales, warp_frag_zero);
run_warp_mma(warp_mma,
accum,
warp_frag_A[warp_mma_k % 2],
converted_frag_B,
accum,
warp_tileB_k_compute_offset);
}
// Load the scales needed for the next tile iteration
warp_dequantizer_.load(warp_frag_scales, warp_frag_zero);
// Update internal pointer to the set of scales in shared memory
warp_dequantizer_.add_pointer_offset(Shape::kN);
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
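// A standalone sketch of the scale-row bookkeeping performed by
// copy_scales_and_advance above, reduced to plain integers; the tile and group
// sizes in main() are hypothetical. With a 64-wide threadblock K tile and a
// quantization group size of 128, the scale row advances only on every other
// K tile, while a group size of 64 advances it on every K tile.
#include <cassert>
int scale_rows_consumed(int k_tiles, int tile_k, int group_size) {
  int row = 0;
  int row_groupsize64 = 0;
  for (int i = 0; i < k_tiles; ++i) {
    if (group_size == 64) {
      ++row;
    } else if (group_size == 128) {
      if (tile_k == 128) {
        ++row;
      } else if (tile_k == 64 && (row_groupsize64 & 0x1)) {
        ++row;  // two 64-wide K tiles share one row of scales
      }
    }
    ++row_groupsize64;
  }
  return row;
}
int main() {
  assert(scale_rows_consumed(8, 64, 64) == 8);
  assert(scale_rows_consumed(8, 64, 128) == 4);
  assert(scale_rows_consumed(8, 128, 128) == 8);
  return 0;
}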
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -53,27 +54,27 @@
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace threadblock
{
namespace cutlass {
namespace gemm {
namespace threadblock {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
/// Structure to compute the matrix product targeting CUDA cores and SIMT math
/// instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Iterates over tiles of A operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorA_,
/// Iterates over tiles of A operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
typename SmemIteratorA_,
/// Iterates over tiles of B operand in global memory
// (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
// (concept: ReadableTileIterator | ForwardTileIterator |
// MaskedTileIterator)
typename IteratorB_,
/// Iterates over tiles of B operand in shared memory
/// (concept: WriteableTileIterator | RandomAccessTileIterator)
@@ -94,306 +95,361 @@ template <
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_>
class DqMmaPipelined<Shape_, IteratorA_, SmemIteratorA_, IteratorB_, SmemIteratorB_, IteratorScale_, SmemIteratorScale_,
ElementC_, LayoutC_, Policy_, TransformBAfterLDG_, TransformBAfterLDS_, QuantOp_,
std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>
{
public:
///< Base class
using Base = DqMmaBase<Shape_, Policy_, typename SmemIteratorScale_::Element, 2, QuantOp_>;
class DqMmaPipelined<Shape_,
IteratorA_,
SmemIteratorA_,
IteratorB_,
SmemIteratorB_,
IteratorScale_,
SmemIteratorScale_,
ElementC_,
LayoutC_,
Policy_,
TransformBAfterLDG_,
TransformBAfterLDS_,
QuantOp_,
std::enable_if_t<!isFinegrained(QuantOp_)>>
: public DqMmaBase<Shape_,
Policy_,
typename SmemIteratorScale_::Element,
2,
QuantOp_> {
public:
///< Base class
using Base = DqMmaBase<Shape_,
Policy_,
typename SmemIteratorScale_::Element,
2,
QuantOp_>;
using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using Shape =
Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<>
using IteratorA =
IteratorA_; ///< Iterates over tiles of A operand in global memory
using IteratorB =
IteratorB_; ///< Iterates over tiles of B operand in global memory
using ElementC = ElementC_; ///< Data type of accumulator matrix
using LayoutC = LayoutC_; ///< Layout of accumulator matrix
using Policy = Policy_; ///< Policy describing tuning details
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using IteratorScale = IteratorScale_;
using ElementScale = typename IteratorScale::Element;
using LayoutScale = typename IteratorScale::Layout;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using SmemIteratorA = SmemIteratorA_;
using SmemIteratorB = SmemIteratorB_;
using SmemIteratorScale = SmemIteratorScale_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
using TransformBAfterLDG = TransformBAfterLDG_;
using TransformBAfterLDS = TransformBAfterLDS_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
//
// Dependent types
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
/// Warp-level Mma
using Operator = typename Policy::Operator;
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
using Dequantizer = warp::MmaTensorOpDequantizer<
Operator,
typename Base::WarpGemm,
Operand::kB,
typename SmemIteratorScale::Fragment::Element,
LayoutScale,
32,
QuantOp>;
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (Double-buffered
// pipeline)
static_assert((Base::kStages == 2),
"DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave =
layout::IsColumnMajorTileInterleave<
typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave ||
(RequiresTileInterleave &&
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared
/// memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(
typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by
///< threadblock-scoped GEMM
        int const
            group_size, ///< Unused; present only so the per-column and
                        ///< fine-grained specializations share one constructor
                        ///< signature. DqMmaPipelined is only enabled for
                        ///< sm<80, so this extra argument has no effect on
                        ///< sm>=80 builds.
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
warp_dequantizer_(
{shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
Base::WarpCount::kM,
lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
smem_iterator_scale_(LayoutScale(Shape::kN),
shared_storage.operand_scale.data(),
{1, Shape::kN},
thread_idx) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset(
{warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset(
{Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
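// Worked example of the mapping above, assuming a hypothetical WarpCount of
// 2x2x1: warp_idx 0..3 gives warp_idx_mn = 0..3 and warp_idx_k = 0, so
// (warp_idx_m, warp_idx_n) = (0,0), (1,0), (0,1), (1,1); the A iterator is
// then offset along M and the B iterator along N for each warp.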
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale
iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum) { ///< source accumulator tile
//
    // Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA = NumericArrayConverter<typename WarpFragmentA::Element,
typename FragmentA::Element,
FragmentA::kElements>;
using TransformScale =
NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element,
FragmentScale::kElements>;
// These transforms are mainly to handle when we have bfloat activations and
// weights in GMEM and want to issue HMMA on architectures older than
// Ampere. We will convert to FP16 before STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math
// instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER*
// issuing shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
/// Fragment of operand A loaded from global memory
using FragmentA = typename IteratorA::Fragment;
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations) {
//
// Loop over GEMM K dimension
//
/// Fragment of operand B loaded from global memory
using FragmentB = typename IteratorB::Fragment;
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
// Load warp-level tiles from shared memory, wrapping to k offset if
// this is the last group as the case may be.
/// Fragment of operand Scale loaded from global memory;
using FragmentScale = typename IteratorScale::Fragment;
if (warp_mma_k == Base::kWarpGemmIterations - 1) {
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
/// Fragment of accumulator tile
using FragmentC = typename Policy::Operator::FragmentC;
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
/// Warp-level Mma
using Operator = typename Policy::Operator;
__syncthreads();
/// Obtain the arch tag from the warp-level operator
using ArchTag = typename Policy::Operator::ArchTag;
++this->smem_iterator_A_;
++this->smem_iterator_B_;
using Dequantizer = warp::MmaTensorOpDequantizer<Operator, typename Base::WarpGemm, Operand::kB,
typename SmemIteratorScale::Fragment::Element, LayoutScale, 32, QuantOp>;
// Add negative offsets to return iterators to the 'start' of the
// circular buffer in shared memory
if (smem_write_stage_idx == 1) {
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpGemmIterationsForB,
0});
}
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
/// Complex transform on B operand
static ComplexTransform const kTransformB = Operator::kTransformB;
// statically assert kStages for DqMmaPipelined is two (Double-buffered pipeline)
static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2");
private:
using WarpFragmentA = typename Operator::FragmentA;
using WarpFragmentB = typename Operator::FragmentB;
Dequantizer warp_dequantizer_;
using ElementA = typename IteratorA::Element;
using ElementB = typename IteratorB::Element;
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool RequiresTileInterleave
= layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
protected:
/// Iterator to write threadblock-scoped tile of A operand to shared memory
SmemIteratorA smem_iterator_A_;
/// Iterator to write threadblock-scoped tile of B operand to shared memory
SmemIteratorB smem_iterator_B_;
/// Iterator to write threadblock-scoped tile of scale operand to shared memory
SmemIteratorScale smem_iterator_scale_;
public:
/// Construct from tensor references
CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
int const group_size, ///< Unused; kept only so the signature matches the fine-grained
///< variants and call sites still compile. DqMmaPipelined is only enabled for
///< sm<80, so for sm>=80 this argument has no effect either way.
int thread_idx, ///< ID within the threadblock
int warp_idx, ///< ID of warp
int lane_idx ///< ID of each thread within a warp
)
: Base(shared_storage, thread_idx, warp_idx, lane_idx)
, warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)},
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx)
, smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx)
, smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx)
, smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx)
{
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
// _n: the warp's position within the threadblock along the N dimension
// _k: the warp's position within the threadblock along the K dimension
int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
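// Worked example (editor's illustration, assuming Base::WarpCount = {2, 2, 1}):
// warp_idx = 3 gives warp_idx_mn = 3, warp_idx_k = 0, warp_idx_m = 1 and
// warp_idx_n = 1, i.e. this warp covers the lower-right warp tile of the
// threadblock and there is a single k-partition.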
// Add per-warp offsets in units of warp-level tiles
this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n});
}
/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop
FragmentC& accum, ///< destination accumulator tile
IteratorA iterator_A, ///< iterator over A operand in global memory
IteratorB iterator_B, ///< iterator over B operand in global memory
IteratorScale iterator_scale, ///< iterator over scale operand in global memory
FragmentC const& src_accum)
{ ///< source accumulator tile
//
// Prologue
//
TransformBAfterLDG ldg_converter;
TransformBAfterLDS lds_converter;
using TransformA
= NumericArrayConverter<typename WarpFragmentA::Element, typename FragmentA::Element, FragmentA::kElements>;
using TransformScale = NumericArrayConverter<typename SmemIteratorScale::Fragment::Element,
typename FragmentScale::Element, FragmentScale::kElements>;
// These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want
// to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS.
TransformA transformA;
TransformScale transformScale;
// Perform accumulation in the 'd' output operand
accum = src_accum;
FragmentA tb_frag_A;
FragmentB tb_frag_B;
FragmentScale tb_frag_scales;
using WarpFragmentScale = typename Dequantizer::FragmentScale;
WarpFragmentScale warp_frag_scales;
tb_frag_A.clear();
tb_frag_B.clear();
tb_frag_scales.clear();
// The last kblock is loaded in the prolog
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
iterator_scale.load(tb_frag_scales);
++iterator_A;
++iterator_B;
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
this->smem_iterator_scale_.store(transformScale(tb_frag_scales));
++this->smem_iterator_A_;
++this->smem_iterator_B_;
__syncthreads();
warp_dequantizer_.load(warp_frag_scales);
// Pair of fragments used to overlap shared memory loads and math instructions
WarpFragmentA warp_frag_A[2];
WarpFragmentB warp_frag_B[2];
this->warp_tile_iterator_A_.set_kgroup_index(0);
this->warp_tile_iterator_B_.set_kgroup_index(0);
this->warp_tile_iterator_A_.load(warp_frag_A[0]);
this->warp_tile_iterator_B_.load(warp_frag_B[0]);
++this->warp_tile_iterator_A_;
++this->warp_tile_iterator_B_;
Operator warp_mma;
int smem_write_stage_idx = 1;
// Avoid reading out of bounds
iterator_A.clear_mask(gemm_k_iterations <= 1);
iterator_B.clear_mask(gemm_k_iterations <= 1);
// Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
// shared memory loads (which have the tightest latency requirement).
//
// Mainloop
//
// Note: The main loop does not support Base::kWarpGemmIterations == 2.
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > 0; --gemm_k_iterations)
{
//
// Loop over GEMM K dimension
//
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
// as the case may be.
if (warp_mma_k == Base::kWarpGemmIterations - 1)
{
// Write fragments to shared memory
this->smem_iterator_A_.store(transformA(tb_frag_A));
this->smem_iterator_B_.store(ldg_converter(tb_frag_B));
__syncthreads();
++this->smem_iterator_A_;
++this->smem_iterator_B_;
// Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
if (smem_write_stage_idx == 1)
{
this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset(
{0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0});
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
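// Editor's note (sketch, not from the original comments): one shared-memory
// load of B feeds Base::kNumKIterationsPerWarpBLoad warp MMA steps. For
// example, with kNumKIterationsPerWarpBLoad = 2 the sequence
// warp_mma_k = 0,1,2,3 yields compute offsets 0,1,0,1 and load offsets
// 0,0,1,1, so a fresh B fragment is requested only on the last compute step
// of each group (see the check just below).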
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0)
{
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
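// Editor's note: the fragments loaded in this branch feed the *next* mainloop
// iteration, so the masks are cleared one iteration earlier than in the
// prologue (gemm_k_iterations <= 2 instead of <= 1) to keep that prefetch
// from reading past the end of the K extent.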
typename TransformBAfterLDS::result_type converted_frag_B
= lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(
warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset);
}
smem_write_stage_idx ^= 1;
}
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
int const warp_tileB_k_compute_offset =
warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
int const warp_tileB_k_load_offset =
warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate
// the load for the next fragment.
if (warp_tileB_k_compute_offset ==
Base::kNumKIterationsPerWarpBLoad - 1) {
this->warp_tile_iterator_B_.set_kgroup_index(
(warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
this->warp_tile_iterator_B_.load(
warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
++this->warp_tile_iterator_B_;
}
if (warp_mma_k == 0) {
iterator_A.load(tb_frag_A);
iterator_B.load(tb_frag_B);
++iterator_A;
++iterator_B;
// Avoid reading out of bounds if this was the last loop iteration
iterator_A.clear_mask(gemm_k_iterations <= 2);
iterator_B.clear_mask(gemm_k_iterations <= 2);
}
typename TransformBAfterLDS::result_type converted_frag_B =
lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
run_warp_mma(warp_mma,
accum,
warp_frag_A[warp_mma_k % 2],
converted_frag_B,
accum,
warp_tileB_k_compute_offset);
}
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -66,7 +66,7 @@ template <
/// Size of extra quantized params
typename QuantParamsShape>
class Wint2xMmaBase {
public:
public:
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
@@ -85,9 +85,9 @@ public:
using WarpGemm = typename Policy::Operator::Shape;
/// Shape describing the number of warps filling the CTA
using WarpCount =
GemmShape<Shape::kM / WarpGemm::kM, Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
using WarpCount = GemmShape<Shape::kM / WarpGemm::kM,
Shape::kN / WarpGemm::kN,
Shape::kK / WarpGemm::kK>;
/// Number of warp-level GEMM operations
static int const kWarpGemmIterations =
@@ -95,7 +95,8 @@ public:
/// Number of warp-level GEMM operations per load for B
static constexpr int kWarpGemmIterationsPerLoadForB =
Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK;
Operator::IteratorB::InstructionShape::kRow /
Operator::InstructionShape::kK;
static_assert(!(kWarpGemmIterations % kWarpGemmIterationsPerLoadForB), "");
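// Editor's note (illustration): kWarpGemmIterationsPerLoadForB is how many
// warp-level MMA steps a single warp load of B covers; for instance, if the B
// iterator's InstructionShape::kRow were 32 while the math instruction's kK is
// 16, one B load would feed two MMA steps, and the assert above guarantees the
// per-stage MMA count splits evenly into such groups.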
static constexpr int kWarpLoadIterationsForB =
@@ -125,7 +126,7 @@ public:
/// Shared storage object needed by threadblock-scoped GEMM
class SharedStorage {
public:
public:
//
// Type definitions
//
@@ -142,7 +143,7 @@ public:
/// Shape of all quant params in shared memory
using QuantParamsShapeB = QuantParamsShape;
public:
public:
//
// Data members
//
@@ -156,7 +157,7 @@ public:
/// Buffer for extra quant params of B operand
AlignedBuffer<uint8_t, QuantParamsShapeB::kCount> operand_quant_params_B;
public:
public:
//
// Methods
//
@@ -186,7 +187,7 @@ public:
}
};
protected:
protected:
//
// Data members
//
@@ -197,7 +198,7 @@ protected:
/// Iterator to load a warp-scoped tile of B operand from shared memory
typename Operator::IteratorB warp_tile_iterator_B_;
public:
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaBase(
@@ -215,8 +216,8 @@ public:
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace threadblock
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
@@ -1,12 +1,12 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
@@ -18,14 +18,15 @@
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
@@ -91,11 +92,17 @@ template <
typename QuantParamsAccessor_,
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone>
class Wint2xMmaMultistage :
public Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape> {
public:
class Wint2xMmaMultistage
: public Wint2xMmaBase<Shape_,
Policy_,
Stages,
typename QuantParamsAccessor_::QuantParamsShape> {
public:
///< Base class
using Base = Wint2xMmaBase<Shape_, Policy_, Stages, typename QuantParamsAccessor_::QuantParamsShape>;
using Base = Wint2xMmaBase<Shape_,
Policy_,
Stages,
typename QuantParamsAccessor_::QuantParamsShape>;
///< Size of the Gemm problem - concept: gemm::GemmShape<>
using Shape = Shape_;
///< Iterates over tiles of A operand in global memory
@@ -133,17 +140,19 @@ public:
/// Minimum architecture is Sm80 to support cp.async
using ArchTag = arch::Sm80;
//using LayoutScale = typename QuantParamsAccessor::IteratorSuperScale::Layout;
// using LayoutScale = typename
// QuantParamsAccessor::IteratorSuperScale::Layout;
using LayoutScale = layout::RowMajor;
using WarpTransformedFragmentB = typename Operator::TransformedFragmentB;
using WarpDequantizer =
warp::MmaTensorOpWin2xDequantizer<Operator,
typename Base::WarpGemm,
Operand::kB,
typename WarpTransformedFragmentB::Element,
LayoutScale,
QuantParamsAccessor::kGroupSize>;
static_assert(sizeof(WarpDequantizer) > 0, "WarpDequantizer template instantiation failed");
using WarpDequantizer = warp::MmaTensorOpWin2xDequantizer<
Operator,
typename Base::WarpGemm,
Operand::kB,
typename WarpTransformedFragmentB::Element,
LayoutScale,
QuantParamsAccessor::kGroupSize>;
static_assert(sizeof(WarpDequantizer) > 0,
"WarpDequantizer template instantiation failed");
/// Complex transform on A operand
static ComplexTransform const kTransformA = Operator::kTransformA;
@@ -153,7 +162,6 @@ public:
/// Internal structure exposed for introspection.
struct Detail {
/// Number of cp.async instructions to load one stage of operand A
static int const AsyncCopyIterationsPerStageA =
IteratorA::ThreadMap::Iterations::kCount;
@@ -167,24 +175,25 @@ public:
/// Number of cp.async instructions to load on group of operand A
static int const kAccessesPerGroupA =
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
(AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
/// Number of cp.async instructions to load on group of operand B
static int const kAccessesPerGroupB =
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations;
(AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) /
Base::kWarpGemmIterations;
// Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical
// accuracy, where each mainloop iteration first accumulates into a temporary
// set of freshly-cleared accumulators, which are subsequently added to the
// final accumulator set.
static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation<Operator>::value;
// Optional staged-accumulation (e.g., tf32x3 kernels) for improved
// numerical accuracy, where each mainloop iteration first accumulates into
// a temporary set of freshly-cleared accumulators, which are subsequently
// added to the final accumulator set.
static bool const kStagedAccumulation =
arch::detail::UseStagedAccumulation<Operator>::value;
};
private:
// Structure encapsulating pipeline state live from one iteration to the next
struct PipeState {
using WarpLoadedFragmentA = typename Operator::FragmentA;
using WarpLoadedFragmentB = typename Operator::FragmentB;
using WarpTransformedFragmentA = typename Operator::TransformedFragmentA;
@@ -197,10 +206,12 @@ public:
/// Temporary accumulator to facilitate staged-accumulation
FragmentC tmp_accum_;
/// Pair of A fragments used to overlap shared memory loads and math instructions
/// Pair of A fragments used to overlap shared memory loads and math
/// instructions
WarpTransformedFragmentA warp_frag_A_[2];
/// Pair of B fragments used to overlap shared memory loads and math instructions
/// Pair of B fragments used to overlap shared memory loads and math
/// instructions
WarpLoadedFragmentB warp_loaded_frag_B_;
WarpTransformedFragmentB warp_frag_B_[2];
@@ -218,12 +229,14 @@ public:
using LayoutDetailsForB = kernel::LayoutDetailsB<ElementA, ElementB, ArchTag>;
static constexpr bool IsTileInterleaveLayout =
layout::IsColumnMajorTileInterleave<typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout || (IsTileInterleaveLayout && (Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
layout::IsColumnMajorTileInterleave<
typename LayoutDetailsForB::Layout>::value;
static_assert(!IsTileInterleaveLayout ||
(IsTileInterleaveLayout &&
(Shape::kK == LayoutDetailsForB::ThreadblockK)),
"Layout K must match threadblockK");
private:
//
// Data members
//
@@ -249,8 +262,7 @@ public:
/// Shared memory read stage index
int smem_read_stage_idx_;
public:
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2xMmaMultistage(
@@ -261,19 +273,24 @@ public:
///< ID of warp
int warp_idx,
///< ID of each thread within a warp
int lane_idx
) : Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(), thread_idx, warp_idx, lane_idx),
warp_dequantizer_(quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0)
{
int lane_idx)
: Base(shared_storage, thread_idx, warp_idx, lane_idx),
smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx),
quant_params_accessor_B_(shared_storage.operand_quant_params_B.data(),
thread_idx,
warp_idx,
lane_idx),
warp_dequantizer_(
quant_params_accessor_B_.super_scale_ref(),
quant_params_accessor_B_.local_scale_ref(),
quant_params_accessor_B_.code_scale_ref(),
quant_params_accessor_B_.code_zp_ref(),
(warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) /
Base::WarpCount::kM,
lane_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {
// Compute warp location within threadblock tile by mapping the warp_id to
// three coordinates:
// _m: the warp's position within the threadblock along the M dimension
@@ -295,22 +312,26 @@ public:
/// Advance shared memory read-iterators to the next stage
CUTLASS_DEVICE
void advance_smem_read_stage()
{
void advance_smem_read_stage() {
++smem_read_stage_idx_;
if (smem_read_stage_idx_ == Base::kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpLoadIterationsForB, 0});
this->warp_tile_iterator_A_.add_tile_offset(
{0,
-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
this->warp_tile_iterator_B_.add_tile_offset(
{-Base::kStages * Policy::kPartitionsK *
Base::kWarpLoadIterationsForB,
0});
smem_read_stage_idx_ = 0;
}
}
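// Editor's note: the warp tile iterators move forward by one k-group per MMA
// step, i.e. kPartitionsK * kWarpGemmIterations groups per stage for A (and
// kWarpLoadIterationsForB groups for B), so the negative tile offsets above
// rewind exactly kStages stages' worth of progress to reuse the circular
// shared-memory buffer.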
/// Advance global memory read-iterators and shared memory write-iterators to the stage
/// Advance global memory read-iterators and shared memory write-iterators to
/// the stage
CUTLASS_DEVICE
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B)
{
void advance_smem_write_stage(IteratorA &iterator_A, IteratorB &iterator_B) {
// Advance global iterators
iterator_A.add_tile_offset({0, 1});
iterator_B.add_tile_offset({1, 0});
@@ -395,7 +416,9 @@ public:
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads) ? iterator_B.valid() : false;
bool is_valid = (threadIdx.x < IteratorB::ThreadMap::kThreads)
? iterator_B.valid()
: false;
if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
@@ -429,10 +452,9 @@ public:
for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_A.get();
int const kSrcBytes =
sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
int const kSrcBytes = sizeof_bits<typename IteratorA::Element>::value *
IteratorA::ThreadMap::kElementsPerAccess /
IteratorA::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
dst_ptr + v, iterator_A.get(), iterator_A.valid());
@@ -464,10 +486,9 @@ public:
for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
auto gmem_ptr = iterator_B.get();
int const kSrcBytes =
sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
int const kSrcBytes = sizeof_bits<typename IteratorB::Element>::value *
IteratorB::ThreadMap::kElementsPerAccess /
IteratorB::kAccessesPerVector / 8;
cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
dst_ptr + v, iterator_B.get(), iterator_B.valid());
@@ -480,18 +501,22 @@ public:
}
/// GEMM prologue. Bootstrap the global->shared memory pipeline by fetching
/// the global fragments needed by the first kStages-1 threadblock mainloop iterations
/// the global fragments needed by the first kStages-1 threadblock mainloop
/// iterations
CUTLASS_DEVICE
void prologue(
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining
void prologue(IteratorA &iterator_A, ///< [in|out] iterator over A operand in
///< global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in
///< global memory
QuantArguments &
mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations) ///< [in|out] number of threadblock
///< mainloop iterations remaining
{
// Issue several complete stages
CUTLASS_PRAGMA_UNROLL
for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) {
for (int stage = 0; stage < Base::kStages - 1;
++stage, --gemm_k_iterations) {
// Disable global fetching if done with global fetch iterations
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
@@ -502,11 +527,14 @@ public:
// Async copy zipped B to shared memory.
copy_tiles_and_advance_per_stage_B(iterator_B);
// Async copy other quantized params to shared memory, local_scale, code_scale, code_zp, super_scale.
// Async copy other quantized params to shared memory, local_scale,
// code_scale, code_zp, super_scale.
if (stage == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(mma_quant_args, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<true>(
mma_quant_args, stage);
} else {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(
mma_quant_args, stage);
}
// Move to the next write stage
@@ -517,11 +545,12 @@ public:
cutlass::arch::cp_async_fence();
}
// Optionally clear the remaining stages of SMEM. This is a functional requirement for
// some kernels so that all accumulator elements outside the GEMM footprint are zero.
// Optionally clear the remaining stages of SMEM. This is a functional
// requirement for some kernels so that all accumulator elements outside the
// GEMM footprint are zero.
if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) {
/// Iterator to write threadblock-scoped tile of A operand to shared memory
/// Iterator to write threadblock-scoped tile of A operand to shared
/// memory
SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_);
typename IteratorA::AccessType zero_A;
@@ -531,7 +560,6 @@ public:
// Async Copy for operand A
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) {
typename IteratorA::AccessType *dst_ptr =
reinterpret_cast<typename IteratorA::AccessType *>(
last_smem_iterator_A.get());
@@ -545,7 +573,8 @@ public:
return;
}
/// Iterator to write threadblock-scoped tile of B operand to shared memory
/// Iterator to write threadblock-scoped tile of B operand to shared
/// memory
SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_);
typename IteratorB::AccessType zero_B;
@@ -555,7 +584,6 @@ public:
// Async Copy for operand B
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) {
typename IteratorB::AccessType *dst_ptr =
reinterpret_cast<typename IteratorB::AccessType *>(
last_smem_iterator_B.get());
@@ -569,9 +597,9 @@ public:
/// Wait until we have at least one completed global fetch stage
CUTLASS_DEVICE
void gmem_wait()
{
// Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed)
void gmem_wait() {
// Wait until we have at least one committed global fetch stage.
// (#uncommitted = Base::kStages - 1 - #committed)
cutlass::arch::cp_async_wait<Base::kStages - 2>();
__syncthreads();
}
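// Editor's note: cp_async_wait<Base::kStages - 2>() allows at most kStages - 2
// committed cp.async groups to remain outstanding, which is the same as
// waiting until at least one full stage has landed in shared memory; the
// __syncthreads() inside gmem_wait then makes that stage visible to all warps
// before it is read.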
@@ -579,25 +607,31 @@ public:
/// Perform a threadblock mainloop iteration of matrix multiply-accumulate
CUTLASS_DEVICE
void mac_loop_iter(
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop iterations remaining
int stage)
{
PipeState &pipe_state, ///< [in|out] loop-carried pipeline state
FragmentC &accum, ///< [in|out] destination accumulator tile
IteratorA
&iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB
&iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments
&mma_quant_args, ///< iterators for extra quant params for B
int &gemm_k_iterations, ///< [in|out] number of threadblock mainloop
///< iterations remaining
int stage) {
const int mma_stage = stage - Base::kStages + 1;
// Unroll the warp-level MMA tiles of a threadblock's mainloop iteration
CUTLASS_PRAGMA_UNROLL
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
int warp_k_compute_offset_B = warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
int warp_k_compute_offset_B =
warp_mma_k % Base::kWarpGemmIterationsPerLoadForB;
if (warp_k_compute_offset_B == Base::kWarpGemmIterationsPerLoadForB - 1) {
// Load the next warp-tile's B fragment from shared memory
this->warp_tile_iterator_B_.set_kgroup_index(((warp_mma_k + 1) % Base::kWarpGemmIterations) / Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.set_kgroup_index(
((warp_mma_k + 1) % Base::kWarpGemmIterations) /
Base::kWarpLoadIterationsForB);
this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_);
++this->warp_tile_iterator_B_;
}
@@ -608,28 +642,31 @@ public:
}
// Load the next warp-tile's A fragment from shared memory
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) %
Base::kWarpGemmIterations);
this->warp_tile_iterator_A_.load(
pipe_state.warp_frag_A_[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_;
// dequantizes next warp-tile
warp_dequantizer_.dequantize(pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1) : mma_stage) * Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
warp_dequantizer_.dequantize(
pipe_state.warp_frag_local_scale_,
pipe_state.warp_frag_code_scale_,
pipe_state.warp_frag_code_zp_,
pipe_state.warp_frag_super_scale_,
pipe_state.warp_loaded_frag_B_,
pipe_state.warp_frag_B_[(warp_mma_k + 1) % 2],
((warp_mma_k == Base::kWarpGemmIterations - 1) ? (mma_stage + 1)
: mma_stage) *
Shape::kK,
(warp_mma_k + 1) % Base::kWarpGemmIterationsPerLoadForB);
// Execute the current warp-tile of MMA operations
if constexpr (Detail::kStagedAccumulation) {
warp_mma_(
pipe_state.tmp_accum_,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_
);
warp_mma_(pipe_state.tmp_accum_,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
pipe_state.tmp_accum_);
if (warp_mma_k == 0) {
plus<FragmentC> plus_accum;
@@ -637,11 +674,10 @@ public:
pipe_state.tmp_accum_.clear();
}
} else {
warp_mma_(
accum,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
warp_mma_(accum,
pipe_state.warp_frag_A_[warp_mma_k % 2],
pipe_state.warp_frag_B_[warp_mma_k % 2],
accum);
}
// Except for the last warp-tile, all warp-tiles issue their share of
@@ -654,22 +690,28 @@ public:
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
if (warp_mma_k == 0) {
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(mma_quant_args, stage);
quant_params_accessor_B_.copy_tiles_and_advance_per_stage<false>(
mma_quant_args, stage);
}
}
// The second-to-last warp-tile also:
// - performs the last warp-tile's share of global->shared fragment copies
// - performs the last warp-tile's share of global->shared fragment
// copies
// - moves to the next global fetch stage
if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
// Performs the last warp-tile's share of global->shared fragment copies
if constexpr (Detail::AsyncCopyIterationsPerStageA >= Base::kWarpGemmIterations) {
int group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
if constexpr (Detail::AsyncCopyIterationsPerStageA >=
Base::kWarpGemmIterations) {
int group_start_iteration_A =
(warp_mma_k + 1) * Detail::kAccessesPerGroupA;
copy_tiles_and_advance_A(iterator_A, group_start_iteration_A);
}
if constexpr (Detail::AsyncCopyIterationsPerStageB >= Base::kWarpGemmIterations) {
int group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
if constexpr (Detail::AsyncCopyIterationsPerStageB >=
Base::kWarpGemmIterations) {
int group_start_iteration_B =
(warp_mma_k + 1) * Detail::kAccessesPerGroupB;
copy_tiles_and_advance_B(iterator_B, group_start_iteration_B);
}
@@ -691,7 +733,8 @@ public:
--gemm_k_iterations;
iterator_A.clear_mask(gemm_k_iterations == 0);
iterator_B.clear_mask(gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args, gemm_k_iterations == 0);
quant_params_accessor_B_.clear_mask(mma_quant_args,
gemm_k_iterations == 0);
}
}
}
@@ -700,12 +743,13 @@ public:
/// multiply-accumulate. Assumes prologue has been initiated.
CUTLASS_DEVICE
void gemm_iters(
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args)
{
int gemm_k_iterations, ///< number of threadblock mainloop iterations
FragmentC &accum, ///< [in|out] accumulator tile
IteratorA
&iterator_A, ///< [in|out] iterator over A operand in global memory
IteratorB
&iterator_B, ///< [in|out] iterator over B operand in global memory
QuantArguments &mma_quant_args) {
PipeState pipe_state;
// Disable global fetching if done with global fetch iterations
@@ -748,14 +792,13 @@ public:
// Mainloop
CUTLASS_GEMM_LOOP
for (; gemm_k_iterations > (-Base::kStages + 1);) {
mac_loop_iter(
pipe_state,
accum,
iterator_A,
iterator_B,
mma_quant_args,
gemm_k_iterations,
stage);
mac_loop_iter(pipe_state,
accum,
iterator_A,
iterator_B,
mma_quant_args,
gemm_k_iterations,
stage);
stage += 1;
}
@@ -764,7 +807,8 @@ public:
accum = plus_accum(accum, pipe_state.tmp_accum_);
}
// Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop
// Commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
@@ -772,15 +816,16 @@ public:
/// Prepares the class for another prologue.
CUTLASS_DEVICE
void wind_down()
{
// Catch-up the smem-read iterator to the smem-write iterator (so this class can be reused for another tile's prologue)
void wind_down() {
// Catch-up the smem-read iterator to the smem-write iterator (so this class can
// be reused for another tile's prologue)
// First, increment remaining warp tiles to get to the next full stage. (Ideally we would
// just decrement one tile, but not all iterators implement --() decrement.)
#pragma unroll
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k)
{
// First, increment remaining warp tiles to get to the next full stage. (Ideally
// we would just decrement one tile, but not all iterators implement --()
// decrement.)
#pragma unroll
for (int warp_mma_k = 1; warp_mma_k < Base::kWarpGemmIterations;
++warp_mma_k) {
this->warp_tile_iterator_A_.set_kgroup_index(warp_mma_k);
this->warp_tile_iterator_B_.set_kgroup_index(warp_mma_k);
@@ -789,22 +834,24 @@ public:
}
smem_read_stage_idx_++;
// Then wrap back two full stages (one for the tile advancing we just did, and one to catch the write iterators)
static const int kStageIters = Policy::kPartitionsK * Base::kWarpGemmIterations;
if (smem_read_stage_idx_ > 1)
{
// Then wrap back two full stages (one for the tile advancing we just did,
// and one to catch the write iterators)
static const int kStageIters =
Policy::kPartitionsK * Base::kWarpGemmIterations;
if (smem_read_stage_idx_ > 1) {
this->warp_tile_iterator_A_.add_tile_offset({0, (-2 * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({(-2 * kStageIters), 0});
}
else
{
this->warp_tile_iterator_A_.add_tile_offset({0, ((Base::kStages - 2) * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset({((Base::kStages - 2) * kStageIters), 0});
} else {
this->warp_tile_iterator_A_.add_tile_offset(
{0, ((Base::kStages - 2) * kStageIters)});
this->warp_tile_iterator_B_.add_tile_offset(
{((Base::kStages - 2) * kStageIters), 0});
}
smem_read_stage_idx_ = smem_write_stage_idx_;
}
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to shared memory.
/// Perform a threadblock-scoped matrix multiply-accumulate, pre-load B to
/// shared memory.
CUTLASS_DEVICE
void operator()(
///< problem size of GEMM
@@ -819,8 +866,8 @@ public:
QuantArguments mma_quant_args,
///< initial value of accumulator
FragmentC const &src_accum) {
// Prologue (start fetching iterations of global fragments into shared memory)
// Prologue (start fetching iterations of global fragments into shared
// memory)
prologue(iterator_A, iterator_B, mma_quant_args, gemm_k_iterations);
// Wait until we have at least one completed global fetch stage
@@ -830,7 +877,8 @@ public:
accum = src_accum;
// Perform the MAC-iterations
gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
gemm_iters(
gemm_k_iterations, accum, iterator_A, iterator_B, mma_quant_args);
}
};
@@ -46,9 +46,10 @@ template <
/// Group size for quantization
int GroupSize_>
class Wint2ParamsAccessor {
public:
static_assert(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
public:
static_assert(platform::is_same<T, half_t>::value ||
platform::is_same<T, bfloat16_t>::value,
"T must be fp16 or bf16");
using ElementType = T;
using Shape = Shape_;
@@ -72,7 +73,7 @@ public:
using ElementLocalScale = typename IteratorLocalScale::Element;
using LayoutLocalScale = typename IteratorLocalScale::Layout;
static_assert(platform::is_same<ElementLocalScale, uint4b_t>::value,
"local_scale's type must be uint4b_t.");
"local_scale's type must be uint4b_t.");
using ElementCodeScaleZp = typename IteratorCodeScaleZp::Element;
using LayoutCodeScaleZp = typename IteratorCodeScaleZp::Layout;
@@ -80,24 +81,32 @@ public:
/// 2 uint4b_t values are stored in a single uint8_t
constexpr static int kStagesPerLocalScaleLoad = 2 * kGroupSize / Shape::kK;
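// Editor's note (illustration with assumed sizes): two uint4b_t local_scale
// values share one byte, so a single local_scale load covers 2 * kGroupSize
// rows of K; with kGroupSize = 64 and Shape::kK = 64 that gives
// kStagesPerLocalScaleLoad = 2, i.e. local_scale only needs to be refreshed
// every other mainloop stage.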
constexpr static int kLocalScaleRows =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn * sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn *
sizeof_bits<ElementLocalScale>::value / 8 / Shape::kN;
using SmemElement = uint8_t;
constexpr static int kSmemRows =
kLocalScaleRows * kStages + sizeof(ElementSuperScale) + sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemRows = kLocalScaleRows * kStages +
sizeof(ElementSuperScale) +
sizeof(ElementCodeScaleZp) * 2;
constexpr static int kSmemColumns = Shape::kN;
using QuantParamsShape = MatrixShape<kSmemRows, kSmemColumns>;
constexpr static int kSuperScaleSmemOffset = 0;
constexpr static int kCodeScaleSmemOffset = kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset = kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset = kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kCodeScaleSmemOffset =
kSmemColumns * sizeof(ElementSuperScale);
constexpr static int kCodeZpSmemOffset =
kCodeScaleSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
constexpr static int kLocalScaleSmemOffset =
kCodeZpSmemOffset + kSmemColumns * sizeof(ElementCodeScaleZp);
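// Editor's note: the offsets above encode a simple packed layout of the
// quant-params buffer over kSmemColumns = Shape::kN columns:
//   [super_scale : kSmemColumns * sizeof(ElementSuperScale) bytes]
//   [code_scale  : kSmemColumns * sizeof(ElementCodeScaleZp) bytes]
//   [code_zp     : kSmemColumns * sizeof(ElementCodeScaleZp) bytes]
//   [local_scale : kLocalScaleRows * kSmemColumns bytes per stage, kStages stages]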
/// TensorRef type for loading element from a tensor
using SuperTensorRef = cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef = cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef = cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
using SuperTensorRef =
cutlass::TensorRef<ElementSuperScale, LayoutSuperScale>;
using LocalTensorRef =
cutlass::TensorRef<ElementLocalScale, LayoutLocalScale>;
using CodeTensorRef =
cutlass::TensorRef<ElementCodeScaleZp, LayoutCodeScaleZp>;
struct Arguments {
IteratorSuperScale iterator_super_scale;
@@ -113,14 +122,14 @@ public:
IteratorCodeScaleZp iterator_code_scale,
IteratorCodeScaleZp iterator_code_zp,
int local_scale_pointer_offset)
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
: iterator_super_scale(iterator_super_scale),
iterator_local_scale(iterator_local_scale),
iterator_code_scale(iterator_code_scale),
iterator_code_zp(iterator_code_zp),
local_scale_pointer_offset(local_scale_pointer_offset) {}
};
private:
private:
//
// Data members
//
@@ -128,13 +137,17 @@ private:
/// Begin address of shared memory
uint8_t* smem_pointer_;
/// Iterator to write threadblock-scoped tile of super scale operand to shared memory
/// Iterator to write threadblock-scoped tile of super scale operand to shared
/// memory
SmemIteratorSuperScale smem_iterator_super_scale_;
/// Iterator to write threadblock-scoped tile of local scale operand to shared memory
/// Iterator to write threadblock-scoped tile of local scale operand to shared
/// memory
SmemIteratorLocalScale smem_iterator_local_scale_;
/// Iterator to write threadblock-scoped tile of code scale operand to shared memory
/// Iterator to write threadblock-scoped tile of code scale operand to shared
/// memory
SmemIteratorCodeScaleZp smem_iterator_code_scale_;
/// Iterator to write threadblock-scoped tile of code zp operand to shared memory
/// Iterator to write threadblock-scoped tile of code zp operand to shared
/// memory
SmemIteratorCodeScaleZp smem_iterator_code_zp_;
/// Shared memory write stage index
@@ -145,25 +158,29 @@ private:
CUTLASS_DEVICE
ElementSuperScale* get_super_scale_smem_ptr() {
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ + kSuperScaleSmemOffset);
return reinterpret_cast<ElementSuperScale*>(smem_pointer_ +
kSuperScaleSmemOffset);
}
CUTLASS_DEVICE
ElementLocalScale* get_local_scale_smem_ptr() {
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ + kLocalScaleSmemOffset);
return reinterpret_cast<ElementLocalScale*>(smem_pointer_ +
kLocalScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_scale_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeScaleSmemOffset);
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ +
kCodeScaleSmemOffset);
}
CUTLASS_DEVICE
ElementCodeScaleZp* get_code_zp_smem_ptr() {
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ + kCodeZpSmemOffset);
return reinterpret_cast<ElementCodeScaleZp*>(smem_pointer_ +
kCodeZpSmemOffset);
}
public:
public:
/// Construct from tensor references
CUTLASS_DEVICE
Wint2ParamsAccessor(
@@ -175,55 +192,74 @@ public:
int warp_idx,
///< ID of each thread within a warp
int lane_idx)
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(), {1, IteratorSuperScale::Shape::kColumn}, thread_idx),
smem_iterator_local_scale_(LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(), {1, IteratorLocalScale::Shape::kColumn}, thread_idx),
smem_iterator_code_scale_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_iterator_code_zp_(LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(), {1, IteratorCodeScaleZp::Shape::kColumn}, thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}
: smem_pointer_(smem_pointer),
smem_iterator_super_scale_(
LayoutSuperScale(IteratorSuperScale::Shape::kColumn),
get_super_scale_smem_ptr(),
{1, IteratorSuperScale::Shape::kColumn},
thread_idx),
smem_iterator_local_scale_(
LayoutLocalScale(IteratorLocalScale::Shape::kColumn),
get_local_scale_smem_ptr(),
{1, IteratorLocalScale::Shape::kColumn},
thread_idx),
smem_iterator_code_scale_(
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_scale_smem_ptr(),
{1, IteratorCodeScaleZp::Shape::kColumn},
thread_idx),
smem_iterator_code_zp_(
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn),
get_code_zp_smem_ptr(),
{1, IteratorCodeScaleZp::Shape::kColumn},
thread_idx),
smem_write_stage_idx_(0),
smem_read_stage_idx_(0) {}
CUTLASS_DEVICE
SuperTensorRef super_scale_ref() {
return {get_super_scale_smem_ptr(), LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
return {get_super_scale_smem_ptr(),
LayoutSuperScale(IteratorSuperScale::Shape::kColumn)};
}
CUTLASS_DEVICE
LocalTensorRef local_scale_ref() {
return {get_local_scale_smem_ptr(), LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
return {get_local_scale_smem_ptr(),
LayoutLocalScale(IteratorLocalScale::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_scale_ref() {
return {get_code_scale_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
return {get_code_scale_smem_ptr(),
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
CUTLASS_DEVICE
CodeTensorRef code_zp_ref() {
return {get_code_zp_smem_ptr(), LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
return {get_code_zp_smem_ptr(),
LayoutCodeScaleZp(IteratorCodeScaleZp::Shape::kColumn)};
}
template <bool IsFirstStage>
CUTLASS_DEVICE
void copy_tiles_and_advance_per_stage(Arguments &quant_args, int stage) {
CUTLASS_DEVICE void copy_tiles_and_advance_per_stage(Arguments& quant_args,
int stage) {
if constexpr (IsFirstStage) {
// Load channel-wise super_scale to shared memory, which only needs to be done once.
// Load channel-wise super_scale to shared memory, which only needs to be
// done once.
typename IteratorSuperScale::Fragment tb_frag_super_scale;
tb_frag_super_scale.clear();
quant_args.iterator_super_scale.load(tb_frag_super_scale);
this->smem_iterator_super_scale_.store(tb_frag_super_scale);
// Load channel-wise code_scale to shared memory, which only needs to be done once.
// Load channel-wise code_scale to shared memory, which only needs to be
// done once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_scale;
tb_frag_code_scale.clear();
quant_args.iterator_code_scale.load(tb_frag_code_scale);
this->smem_iterator_code_scale_.store(tb_frag_code_scale);
// Load channel-wise code_zp to shared memory, which only needs to be done once.
// Load channel-wise code_zp to shared memory, which only needs to be done
// once.
typename IteratorCodeScaleZp::Fragment tb_frag_code_zp;
tb_frag_code_zp.clear();
quant_args.iterator_code_zp.load(tb_frag_code_zp);
@@ -231,20 +267,24 @@ public:
}
if ((stage % kStagesPerLocalScaleLoad) == 0) {
// Load group-wise local_scale to shared memory, which only needs to be done at each stage.
// Since 2 uint4b_t values of local_scale are saved in a single uint8_t, local_scale needs to be loaded once every two stages.
// Load group-wise local_scale to shared memory, which only needs to be
// done at each stage. Since 2 uint4b_t values of local_scale are saved in
// a single uint8_t, local_scale needs to be loaded once every two stages.
using AccessType = typename IteratorLocalScale::AccessType;
cutlass::arch::CacheOperation::Kind const kCacheOp = (sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global : cutlass::arch::CacheOperation::Always;
cutlass::arch::CacheOperation::Kind const kCacheOp =
(sizeof_bits<AccessType>::value == 128)
? cutlass::arch::CacheOperation::Global
: cutlass::arch::CacheOperation::Always;
quant_args.iterator_local_scale.set_iteration_index(0);
this->smem_iterator_local_scale_.set_iteration_index(0);
// Async Copy for local_scale
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount; ++j) {
AccessType *dst_ptr =
reinterpret_cast<AccessType *>(this->smem_iterator_local_scale_.get());
for (int j = 0; j < IteratorLocalScale::ThreadMap::Iterations::kCount;
++j) {
AccessType* dst_ptr = reinterpret_cast<AccessType*>(
this->smem_iterator_local_scale_.get());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < IteratorLocalScale::kAccessesPerVector; ++v) {
@@ -255,8 +295,8 @@ public:
IteratorLocalScale::ThreadMap::kElementsPerAccess /
IteratorLocalScale::kAccessesPerVector / 8;
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
cutlass::arch::cp_async<kSrcBytes, kCacheOp>(
dst_ptr + v, gmem_ptr, quant_args.iterator_local_scale.valid());
}
++quant_args.iterator_local_scale;
}
@@ -265,13 +305,15 @@ public:
}
CUTLASS_DEVICE
void advance_smem_write_stage(Arguments &quant_args) {
void advance_smem_write_stage(Arguments& quant_args) {
if (smem_write_stage_idx_ % kStagesPerLocalScaleLoad == 0) {
// Advance global iterators
quant_args.iterator_local_scale.add_pointer_offset(quant_args.local_scale_pointer_offset);
quant_args.iterator_local_scale.add_pointer_offset(
quant_args.local_scale_pointer_offset);
// Advance shared iterators
int smem_pointer_offset = IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
int smem_pointer_offset =
IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(smem_pointer_offset);
}
@@ -280,7 +322,8 @@ public:
if (smem_write_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
// Wrap back around to the 'start' of the circular buffer in shared memory
int pointer_offset = - kStages * IteratorLocalScale::Shape::kRow * IteratorLocalScale::Shape::kColumn;
int pointer_offset = -kStages * IteratorLocalScale::Shape::kRow *
IteratorLocalScale::Shape::kColumn;
smem_iterator_local_scale_.add_pointer_offset(pointer_offset);
smem_write_stage_idx_ = 0;
}
@@ -298,14 +341,14 @@ public:
if (smem_read_stage_idx_ == kStagesPerLocalScaleLoad * kStages) {
smem_read_stage_idx_ = 0;
byte_offset = - (kStages - 1) * kLocalScaleRows * kSmemColumns;
byte_offset = -(kStages - 1) * kLocalScaleRows * kSmemColumns;
}
return byte_offset;
}
CUTLASS_DEVICE
int clear_mask(Arguments &quant_args, bool cond) {
int clear_mask(Arguments& quant_args, bool cond) {
quant_args.iterator_local_scale.clear_mask(cond);
}
};
@@ -29,18 +29,27 @@ namespace gemm {
namespace threadblock {
template <typename T, int N>
using UnzipArray = cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
using UnzipArray =
cutlass::AlignedArray<T, N, (N * cutlass::sizeof_bits<T>::value / 8)>;
template <typename T, WintQuantMethod QuantMethod, int TileRows,
int TileColumns, int NumThreads = 128>
template <typename T,
WintQuantMethod QuantMethod,
int TileRows,
int TileColumns,
int NumThreads = 128>
struct UnzipAndDequantFunctor {
__device__ void operator()(const T *in_ptr, const T *supper_scale_ptr,
T *out_ptr, const int64_t in_stride) {}
__device__ void operator()(const T *in_ptr,
const T *supper_scale_ptr,
T *out_ptr,
const int64_t in_stride) {}
};
template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
TileColumns, NumThreads> {
struct UnzipAndDequantFunctor<T,
WintQuantMethod::kWeightOnlyInt25,
TileRows,
TileColumns,
NumThreads> {
using ZippedT = uint16_t;
using ScaleComputeT = float;
@@ -52,7 +61,8 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
static constexpr int32_t kLocalScaleMask = 0x1FFF;
static constexpr int32_t kBZP = 4;
__device__ inline T Compute(int32_t zipped_value, int32_t shift_bit,
__device__ inline T Compute(int32_t zipped_value,
int32_t shift_bit,
ScaleComputeT scale) {
int32_t shifted_value = (zipped_value >> shift_bit) & kWeightMask;
int32_t value = shifted_value - kBZP;
@@ -61,8 +71,10 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
return static_cast<T>(scaled_value);
}
__device__ void operator()(const uint16_t *in_ptr, const T *super_scale_ptr,
T *out_ptr, const int64_t in_stride) {
__device__ void operator()(const uint16_t *in_ptr,
const T *super_scale_ptr,
T *out_ptr,
const int64_t in_stride) {
int32_t shift_bits[7] = {13, 11, 9, 6, 4, 2, 0};
int tid = threadIdx.x;
@@ -111,8 +123,11 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt25, TileRows,
};
template <typename T, int TileRows, int TileColumns, int NumThreads>
struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
TileColumns, NumThreads> {
struct UnzipAndDequantFunctor<T,
WintQuantMethod::kWeightOnlyInt2,
TileRows,
TileColumns,
NumThreads> {
using ZippedT = uint8_t;
using ScaleComputeT = float;
@@ -129,9 +144,11 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
// super_scale [N] T
// code_scale, code_zp and super_scale
static constexpr int32_t kColumnWiseSmemBytes = (2 * sizeof(float) + sizeof(T)) * TileColumns;
static constexpr int32_t kColumnWiseSmemBytes =
(2 * sizeof(float) + sizeof(T)) * TileColumns;
// zipped weights and local_scale
static constexpr int32_t kZippedSmemBytes = (TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
static constexpr int32_t kZippedSmemBytes =
(TileRows / 4 + (TileRows + 127) / 128) * TileColumns;
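// Editor's worked example (assuming TileRows = 128, TileColumns = 64 and
// T = half): kZippedSmemBytes = (32 + 1) * 64 = 2112 bytes of packed weights
// plus local_scale, and kColumnWiseSmemBytes = (2 * 4 + 2) * 64 = 640 bytes of
// code_scale, code_zp and super_scale.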
struct Arguments {
uint8_t *weight_ptr;
@@ -140,14 +157,20 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
float *code_zp_ptr;
T *super_scale_ptr;
__device__ Arguments() : weight_ptr(nullptr), local_scale_ptr(nullptr), code_scale_ptr(nullptr), code_zp_ptr(nullptr), super_scale_ptr(nullptr) {}
__device__ Arguments()
: weight_ptr(nullptr),
local_scale_ptr(nullptr),
code_scale_ptr(nullptr),
code_zp_ptr(nullptr),
super_scale_ptr(nullptr) {}
__device__ explicit Arguments(uint8_t *smem_ptr) {
SetZippedPtrs(smem_ptr);
SetColumnWisePtrs(smem_ptr + kZippedSmemBytes);
}
__device__ Arguments(uint8_t *zipped_smem_ptr, uint8_t *column_wise_smem_ptr) {
__device__ Arguments(uint8_t *zipped_smem_ptr,
uint8_t *column_wise_smem_ptr) {
SetZippedPtrs(zipped_smem_ptr);
SetColumnWisePtrs(column_wise_smem_ptr);
}
@@ -159,15 +182,21 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
__device__ void SetColumnWisePtrs(uint8_t *column_wise_smem_ptr) {
code_scale_ptr = reinterpret_cast<float *>(column_wise_smem_ptr);
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr + sizeof(float) * TileColumns);
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr + 2 * sizeof(float) * TileColumns);
code_zp_ptr = reinterpret_cast<float *>(column_wise_smem_ptr +
sizeof(float) * TileColumns);
super_scale_ptr = reinterpret_cast<T *>(column_wise_smem_ptr +
2 * sizeof(float) * TileColumns);
}
};
__device__ void Load(const uint8_t *g_weight_ptr, const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr, const float *g_code_zp_ptr,
__device__ void Load(const uint8_t *g_weight_ptr,
const uint8_t *g_local_scale_ptr,
const float *g_code_scale_ptr,
const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
Arguments *args,
const int64_t in_stride,
bool need_preload) {
int tid = threadIdx.x;
#pragma unroll
@@ -186,7 +215,8 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
#pragma unroll
for (int ls_row_id = 0; ls_row_id < TileRows / 128; ++ls_row_id) {
int local_scale_offset = ls_row_id * in_stride + col;
args->local_scale_ptr[ls_row_id * TileColumns + col] = g_local_scale_ptr[local_scale_offset];
args->local_scale_ptr[ls_row_id * TileColumns + col] =
g_local_scale_ptr[local_scale_offset];
}
#pragma unroll
@@ -205,10 +235,12 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
const float *g_code_scale_ptr,
const float *g_code_zp_ptr,
const T *g_super_scale_ptr,
Arguments *args, const int64_t in_stride, bool need_preload) {
Arguments *args,
const int64_t in_stride,
bool need_preload) {
int tid = threadIdx.x;
constexpr int kBytesPerThread = 16; // 16B per thread
constexpr int kBytesPerThread = 16; // 16B per thread
constexpr int weight_size = TileRows / 4 * TileColumns;
constexpr int local_scale_size = (TileRows + 127) / 128 * TileColumns;
@@ -216,87 +248,130 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
constexpr int code_zp_size = sizeof(float) * TileColumns;
constexpr int super_scale_size = sizeof(T) * TileColumns;
constexpr int total_size = weight_size + local_scale_size + code_scale_size + code_zp_size + super_scale_size;
constexpr int total_size = weight_size + local_scale_size +
code_scale_size + code_zp_size +
super_scale_size;
constexpr int total_tasks = total_size / kBytesPerThread;
constexpr int cur_num_threads = total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
constexpr int cur_num_threads =
total_tasks / ((total_tasks + NumThreads - 1) / NumThreads);
constexpr int weight_threads = weight_size * cur_num_threads / total_size;
constexpr int local_scale_threads = local_scale_size * cur_num_threads / total_size;
constexpr int code_scale_threads = code_scale_size * cur_num_threads / total_size;
constexpr int local_scale_threads =
local_scale_size * cur_num_threads / total_size;
constexpr int code_scale_threads =
code_scale_size * cur_num_threads / total_size;
constexpr int code_zp_threads = code_zp_size * cur_num_threads / total_size;
constexpr int super_scale_threads = super_scale_size * cur_num_threads / total_size;
constexpr int super_scale_threads =
super_scale_size * cur_num_threads / total_size;
static_assert(TileColumns % weight_threads == 0,
"TileColumns must be divisible by weight_threads to ensure correct thread mapping.");
"TileColumns must be divisible by weight_threads to ensure "
"correct thread mapping.");
static_assert(TileColumns % local_scale_threads == 0,
"TileColumns must be divisible by local_scale_threads to ensure correct thread mapping.");
"TileColumns must be divisible by local_scale_threads to "
"ensure correct thread mapping.");
if (tid < weight_threads) {
constexpr int weight_per_thread_size = weight_size / weight_threads;
constexpr int kIterations = (weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int kIterations =
(weight_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
int z_offset = (tid * weight_per_thread_size + i * kBytesPerThread);
int g_offset =
z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->weight_ptr + z_offset, g_weight_ptr + g_offset, true);
}
} else if (tid < weight_threads + local_scale_threads) {
constexpr int start_thread_id = weight_threads;
constexpr int local_scale_per_thread_size = local_scale_size / local_scale_threads;
constexpr int kIterations = (local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int local_scale_per_thread_size =
local_scale_size / local_scale_threads;
constexpr int kIterations =
(local_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size + i * kBytesPerThread;
int g_offset = z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->local_scale_ptr + z_offset, g_local_scale_ptr + g_offset, true);
int z_offset = (tid - start_thread_id) * local_scale_per_thread_size +
i * kBytesPerThread;
int g_offset =
z_offset / TileColumns * in_stride + z_offset % TileColumns;
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->local_scale_ptr + z_offset,
g_local_scale_ptr + g_offset,
true);
}
} else if (need_preload) {
if (tid < weight_threads + local_scale_threads + code_scale_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads;
constexpr int code_scale_per_thread_size = code_scale_size / code_scale_threads;
constexpr int kIterations = (code_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int code_scale_per_thread_size =
code_scale_size / code_scale_threads;
constexpr int kIterations =
(code_scale_per_thread_size + kBytesPerThread - 1) /
kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_scale_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
int offset = ((tid - start_thread_id) * code_scale_per_thread_size +
i * kBytesPerThread) /
sizeof(float);
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->code_scale_ptr + offset, g_code_scale_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads;
} else if (tid < weight_threads + local_scale_threads +
code_scale_threads + code_zp_threads) {
constexpr int start_thread_id =
weight_threads + local_scale_threads + code_scale_threads;
constexpr int code_zp_per_thread_size = code_zp_size / code_zp_threads;
constexpr int kIterations = (code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int kIterations =
(code_zp_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * code_zp_per_thread_size + i * kBytesPerThread) / sizeof(float);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
int offset = ((tid - start_thread_id) * code_zp_per_thread_size +
i * kBytesPerThread) /
sizeof(float);
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->code_zp_ptr + offset, g_code_zp_ptr + offset, true);
}
} else if (tid < weight_threads + local_scale_threads + code_scale_threads + code_zp_threads + super_scale_threads) {
} else if (tid < weight_threads + local_scale_threads +
code_scale_threads + code_zp_threads +
super_scale_threads) {
if (g_super_scale_ptr) {
constexpr int start_thread_id = weight_threads + local_scale_threads + code_scale_threads + code_zp_threads;
constexpr int super_scale_per_thread_size = super_scale_size / super_scale_threads;
constexpr int kIterations = (super_scale_per_thread_size + kBytesPerThread - 1) / kBytesPerThread;
constexpr int start_thread_id = weight_threads + local_scale_threads +
code_scale_threads + code_zp_threads;
constexpr int super_scale_per_thread_size =
super_scale_size / super_scale_threads;
constexpr int kIterations =
(super_scale_per_thread_size + kBytesPerThread - 1) /
kBytesPerThread;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kIterations; ++i) {
int offset = ((tid - start_thread_id) * super_scale_per_thread_size + i * kBytesPerThread) / sizeof(T);
cutlass::arch::cp_async<kBytesPerThread, cutlass::arch::CacheOperation::Global>(
args->super_scale_ptr + offset, g_super_scale_ptr + offset, true);
int offset =
((tid - start_thread_id) * super_scale_per_thread_size +
i * kBytesPerThread) /
sizeof(T);
cutlass::arch::cp_async<kBytesPerThread,
cutlass::arch::CacheOperation::Global>(
args->super_scale_ptr + offset,
g_super_scale_ptr + offset,
true);
}
}
}
}
}
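// A minimal sketch of the flat-to-global offset mapping used by the cp_async
// loops above (GlobalOffsetForTileByte is a hypothetical helper, not used by
// the functor): a flat byte offset within the tile is split into a row
// (offset / TileColumns) and a column (offset % TileColumns), and the row is
// re-strided by the global leading dimension in_stride.
__device__ static int64_t GlobalOffsetForTileByte(int z_offset,
                                                  int64_t in_stride) {
  return static_cast<int64_t>(z_offset / TileColumns) * in_stride +
         z_offset % TileColumns;
}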
__device__ void Compute(const Arguments &args, T *out_ptr,
__device__ void Compute(const Arguments &args,
T *out_ptr,
const int64_t block_start_row) {
int32_t shift_bits[4] = {9, 6, 3, 0};
@@ -333,9 +408,9 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
#pragma unroll
for (int zipped_row = 0; zipped_row < 16; ++zipped_row) {
int32_t decode_value =
static_cast<int32_t>(floor(zipped_value[zipped_row] * code_scale + code_zp +
static_cast<ScaleComputeT>(0.5)));
int32_t decode_value = static_cast<int32_t>(
floor(zipped_value[zipped_row] * code_scale + code_zp +
static_cast<ScaleComputeT>(0.5)));
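// zipped_value holds four packed weight codes behind an affine
// (code_scale, code_zp) encoding; adding 0.5 before floor() makes the decode
// round to the nearest integer, and the per-row codes are then unpacked from
// decode_value at the bit offsets given by shift_bits.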
int row = group_id * 64 + zipped_row * 4;
@@ -355,14 +430,17 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
__syncthreads();
}
__device__ void ComputeVectorized(const Arguments &args, T *out_ptr,
__device__ void ComputeVectorized(const Arguments &args,
T *out_ptr,
const int64_t block_start_row) {
constexpr int kNumWeightsPerThread = TileRows * TileColumns / (4 * NumThreads);
constexpr int kNumWeightsPerThread =
TileRows * TileColumns / (4 * NumThreads);
constexpr int N = (kNumWeightsPerThread >= 32) ? 4 : 2;
constexpr int RowStride = NumThreads * N / TileColumns;
constexpr int kNumIters = kNumWeightsPerThread / N;
static_assert(N * NumThreads >= TileColumns, "N * NumThreads should be no less than TileColumns.");
static_assert(N * NumThreads >= TileColumns,
"N * NumThreads should be no less than TileColumns.");
constexpr ScaleComputeT decode_value_zp = static_cast<ScaleComputeT>(0.5);
@@ -373,19 +451,22 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
static_assert(TileRows <= 128, "TileRows is expected to be no more than 128.");
UnzipArray<uint8_t, N> local_scales =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr + begin_col_id);
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.local_scale_ptr +
begin_col_id);
UnzipArray<uint8_t, N> zipped_values[2];
int zipped_offset = begin_row_id * TileColumns + begin_col_id;
zipped_values[0] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
zipped_values[0] = *reinterpret_cast<const UnzipArray<uint8_t, N> *>(
args.weight_ptr + zipped_offset);
UnzipArray<T, N> super_scales =
*reinterpret_cast<const UnzipArray<T, N> *>(args.super_scale_ptr + begin_col_id);
UnzipArray<T, N> super_scales = *reinterpret_cast<const UnzipArray<T, N> *>(
args.super_scale_ptr + begin_col_id);
UnzipArray<float, N> code_scales =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr + begin_col_id);
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_scale_ptr +
begin_col_id);
UnzipArray<float, N> code_zps =
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr + begin_col_id);
*reinterpret_cast<const UnzipArray<float, N> *>(args.code_zp_ptr +
begin_col_id);
// Special case for TileRows = 64: pick the local_scale nibble for this half
int local_scale_shift = (((block_start_row / 64) + 1) & 1) * 4;
@@ -394,9 +475,10 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t shifted_local_scale =
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) & kLocalScaleMask;
scales[i] =
static_cast<ScaleComputeT>(shifted_local_scale) * static_cast<ScaleComputeT>(super_scales[i]);
(static_cast<int32_t>(local_scales[i]) >> local_scale_shift) &
kLocalScaleMask;
scales[i] = static_cast<ScaleComputeT>(shifted_local_scale) *
static_cast<ScaleComputeT>(super_scales[i]);
}
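// Scale reconstruction sketch: each local_scale byte packs two 4-bit
// sub-scales for a 128-row group; the shift above selects the nibble for this
// 64-row half (block_start_row / 64 even -> shift 4, odd -> shift 0), and the
// per-column dequantization scale becomes
//   scale = (4-bit local_scale) * super_scale,
// with each weight finished below as scale * ((code & kWeightMask) - kBZP).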
#pragma unroll
@@ -405,26 +487,33 @@ struct UnzipAndDequantFunctor<T, WintQuantMethod::kWeightOnlyInt2, TileRows,
int row = zipped_row * 4;
if (iter_id < kNumIters - 1) {
int zipped_offset = (zipped_row + RowStride) * TileColumns + begin_col_id;
int zipped_offset =
(zipped_row + RowStride) * TileColumns + begin_col_id;
zipped_values[(iter_id + 1) & 1] =
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr + zipped_offset);
*reinterpret_cast<const UnzipArray<uint8_t, N> *>(args.weight_ptr +
zipped_offset);
}
UnzipArray<T, N> outs[4];
#pragma unroll
for (int i = 0; i < N; ++i) {
int32_t decode_value =
static_cast<int32_t>(floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) * code_scales[i]
+ code_zps[i] + decode_value_zp));
int32_t decode_value = static_cast<int32_t>(
floor(static_cast<ScaleComputeT>(zipped_values[iter_id & 1][i]) *
code_scales[i] +
code_zps[i] + decode_value_zp));
ScaleComputeT value_3 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_3 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_2 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_2 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_1 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_1 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
decode_value >>= 3;
ScaleComputeT value_0 = static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
ScaleComputeT value_0 =
static_cast<ScaleComputeT>((decode_value & kWeightMask) - kBZP);
outs[0][i] = static_cast<T>(scales[i] * value_0);
outs[1][i] = static_cast<T>(scales[i] * value_1);
outs[2][i] = static_cast<T>(scales[i] * value_2);